TensorExpr.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
12 
13 #include "./InternalHeaderCheck.h"
14 
15 namespace Eigen {
16 
32 namespace internal {
33 template<typename NullaryOp, typename XprType>
34 struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
35  : traits<XprType>
36 {
37  typedef traits<XprType> XprTraits;
38  typedef typename XprType::Scalar Scalar;
39  typedef typename XprType::Nested XprTypeNested;
40  typedef std::remove_reference_t<XprTypeNested> XprTypeNested_;
41  static constexpr int NumDimensions = XprTraits::NumDimensions;
42  static constexpr int Layout = XprTraits::Layout;
43  typedef typename XprTraits::PointerType PointerType;
44  enum {
45  Flags = 0
46  };
47 };
48 
49 } // end namespace internal
50 
51 
52 
53 template<typename NullaryOp, typename XprType>
54 class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors>
55 {
56  public:
57  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
59  typedef typename XprType::CoeffReturnType CoeffReturnType;
61  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
62  typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
63 
64  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
65  : m_xpr(xpr), m_functor(func) {}
66 
69  nestedExpression() const { return m_xpr; }
70 
72  const NullaryOp& functor() const { return m_functor; }
73 
74  protected:
75  typename XprType::Nested m_xpr;
76  const NullaryOp m_functor;
77 };
78 
79 
80 
81 namespace internal {
82 template<typename UnaryOp, typename XprType>
83 struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
84  : traits<XprType>
85 {
86  // TODO(phli): Add InputScalar, InputPacket. Check references to
87  // current Scalar/Packet to see if the intent is Input or Output.
88  typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
89  typedef traits<XprType> XprTraits;
90  typedef typename XprType::Nested XprTypeNested;
91  typedef std::remove_reference_t<XprTypeNested> XprTypeNested_;
92  static constexpr int NumDimensions = XprTraits::NumDimensions;
93  static constexpr int Layout = XprTraits::Layout;
94  typedef typename TypeConversion<Scalar,
95  typename XprTraits::PointerType
96  >::type
97  PointerType;
98 };
99 
100 template<typename UnaryOp, typename XprType>
101 struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense>
102 {
103  typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type;
104 };
105 
106 template<typename UnaryOp, typename XprType>
107 struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type>
108 {
109  typedef TensorCwiseUnaryOp<UnaryOp, XprType> type;
110 };
111 
112 } // end namespace internal
113 
114 
115 
116 template<typename UnaryOp, typename XprType>
117 class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors>
118 {
119  public:
120  // TODO(phli): Add InputScalar, InputPacket. Check references to
121  // current Scalar/Packet to see if the intent is Input or Output.
122  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
125  typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
126  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
127  typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
128 
129  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
130  : m_xpr(xpr), m_functor(func) {}
131 
133  const UnaryOp& functor() const { return m_functor; }
134 
138  nestedExpression() const { return m_xpr; }
139 
140  protected:
141  typename XprType::Nested m_xpr;
142  const UnaryOp m_functor;
143 };
144 
145 
146 namespace internal {
147 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
148 struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
149 {
150  // Type promotion to handle the case where the types of the lhs and the rhs
151  // are different.
152  // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
153  // current Scalar/Packet to see if the intent is Inputs or Output.
154  typedef typename result_of<
155  BinaryOp(typename LhsXprType::Scalar,
156  typename RhsXprType::Scalar)>::type Scalar;
157  typedef traits<LhsXprType> XprTraits;
158  typedef typename promote_storage_type<
160  typename traits<RhsXprType>::StorageKind>::ret StorageKind;
161  typedef typename promote_index_type<
162  typename traits<LhsXprType>::Index,
163  typename traits<RhsXprType>::Index>::type Index;
164  typedef typename LhsXprType::Nested LhsNested;
165  typedef typename RhsXprType::Nested RhsNested;
166  typedef std::remove_reference_t<LhsNested> LhsNested_;
167  typedef std::remove_reference_t<RhsNested> RhsNested_;
168  static constexpr int NumDimensions = XprTraits::NumDimensions;
169  static constexpr int Layout = XprTraits::Layout;
170  typedef typename TypeConversion<Scalar,
171  std::conditional_t<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
174  >::type
175  PointerType;
176  enum {
177  Flags = 0
178  };
179 };
180 
181 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
182 struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense>
183 {
184  typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type;
185 };
186 
187 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
188 struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type>
189 {
190  typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type;
191 };
192 
193 } // end namespace internal
194 
195 
196 
197 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
198 class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors>
199 {
200  public:
201  // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
202  // current Scalar/Packet to see if the intent is Inputs or Output.
203  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
206  typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
207  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
208  typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
209 
210  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
211  : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
212 
214  const BinaryOp& functor() const { return m_functor; }
215 
219  lhsExpression() const { return m_lhs_xpr; }
220 
223  rhsExpression() const { return m_rhs_xpr; }
224 
225  protected:
226  typename LhsXprType::Nested m_lhs_xpr;
227  typename RhsXprType::Nested m_rhs_xpr;
229 };
230 
231 
232 namespace internal {
233 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
234 struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
235 {
236  // Type promotion to handle the case where the types of the args are different.
237  typedef typename result_of<
238  TernaryOp(typename Arg1XprType::Scalar,
239  typename Arg2XprType::Scalar,
240  typename Arg3XprType::Scalar)>::type Scalar;
241  typedef traits<Arg1XprType> XprTraits;
242  typedef typename traits<Arg1XprType>::StorageKind StorageKind;
243  typedef typename traits<Arg1XprType>::Index Index;
244  typedef typename Arg1XprType::Nested Arg1Nested;
245  typedef typename Arg2XprType::Nested Arg2Nested;
246  typedef typename Arg3XprType::Nested Arg3Nested;
247  typedef std::remove_reference_t<Arg1Nested> Arg1Nested_;
248  typedef std::remove_reference_t<Arg2Nested> Arg2Nested_;
249  typedef std::remove_reference_t<Arg3Nested> Arg3Nested_;
250  static constexpr int NumDimensions = XprTraits::NumDimensions;
251  static constexpr int Layout = XprTraits::Layout;
252  typedef typename TypeConversion<Scalar,
253  std::conditional_t<Pointer_type_promotion<typename Arg2XprType::Scalar, Scalar>::val,
256  >::type
257  PointerType;
258  enum {
259  Flags = 0
260  };
261 };
262 
263 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
264 struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
265 {
266  typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
267 };
268 
269 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
270 struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
271 {
272  typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
273 };
274 
275 } // end namespace internal
276 
277 
278 
279 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
280 class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
281 {
282  public:
283  typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
286  typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
287  typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
288  typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
289 
290  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
291  : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
292 
294  const TernaryOp& functor() const { return m_functor; }
295 
299  arg1Expression() const { return m_arg1_xpr; }
300 
303  arg2Expression() const { return m_arg2_xpr; }
304 
307  arg3Expression() const { return m_arg3_xpr; }
308 
309  protected:
310  typename Arg1XprType::Nested m_arg1_xpr;
311  typename Arg2XprType::Nested m_arg2_xpr;
312  typename Arg3XprType::Nested m_arg3_xpr;
313  const TernaryOp m_functor;
314 };
315 
316 
317 namespace internal {
318 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
319 struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
320  : traits<ThenXprType>
321 {
322  typedef typename traits<ThenXprType>::Scalar Scalar;
323  typedef traits<ThenXprType> XprTraits;
324  typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
325  typename traits<ElseXprType>::StorageKind>::ret StorageKind;
326  typedef typename promote_index_type<typename traits<ElseXprType>::Index,
327  typename traits<ThenXprType>::Index>::type Index;
328  typedef typename IfXprType::Nested IfNested;
329  typedef typename ThenXprType::Nested ThenNested;
330  typedef typename ElseXprType::Nested ElseNested;
331  static constexpr int NumDimensions = XprTraits::NumDimensions;
332  static constexpr int Layout = XprTraits::Layout;
333  typedef std::conditional_t<Pointer_type_promotion<typename ThenXprType::Scalar, Scalar>::val,
335  typename traits<ElseXprType>::PointerType> PointerType;
336 };
337 
338 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
339 struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
340 {
341  typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
342 };
343 
344 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
345 struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
346 {
347  typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
348 };
349 
350 } // end namespace internal
351 
352 
353 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
354 class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors>
355 {
356  public:
357  typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
359  typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
360  typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
361  typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
362  typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
363  typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
364 
366  TensorSelectOp(const IfXprType& a_condition,
367  const ThenXprType& a_then,
368  const ElseXprType& a_else)
369  : m_condition(a_condition), m_then(a_then), m_else(a_else)
370  { }
371 
373  const IfXprType& ifExpression() const { return m_condition; }
374 
376  const ThenXprType& thenExpression() const { return m_then; }
377 
379  const ElseXprType& elseExpression() const { return m_else; }
380 
381  protected:
382  typename IfXprType::Nested m_condition;
383  typename ThenXprType::Nested m_then;
384  typename ElseXprType::Nested m_else;
385 };
386 
387 
388 } // end namespace Eigen
389 
390 #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
#define EIGEN_DEVICE_FUNC
The tensor base class.
Eigen::internal::traits< TensorCwiseBinaryOp >::Index Index
Definition: TensorExpr.h:208
RhsXprType::Nested m_rhs_xpr
Definition: TensorExpr.h:227
const internal::remove_all_t< typename LhsXprType::Nested > & lhsExpression() const
Definition: TensorExpr.h:219
const BinaryOp m_functor
Definition: TensorExpr.h:228
Eigen::internal::nested< TensorCwiseBinaryOp >::type Nested
Definition: TensorExpr.h:206
const internal::remove_all_t< typename RhsXprType::Nested > & rhsExpression() const
Definition: TensorExpr.h:223
Eigen::internal::traits< TensorCwiseBinaryOp >::Scalar Scalar
Definition: TensorExpr.h:203
const BinaryOp & functor() const
Definition: TensorExpr.h:214
Eigen::internal::traits< TensorCwiseBinaryOp >::StorageKind StorageKind
Definition: TensorExpr.h:207
LhsXprType::Nested m_lhs_xpr
Definition: TensorExpr.h:226
TensorCwiseBinaryOp(const LhsXprType &lhs, const RhsXprType &rhs, const BinaryOp &func=BinaryOp())
Definition: TensorExpr.h:210
Eigen::NumTraits< Scalar >::Real RealScalar
Definition: TensorExpr.h:204
const NullaryOp & functor() const
Definition: TensorExpr.h:72
const internal::remove_all_t< typename XprType::Nested > & nestedExpression() const
Definition: TensorExpr.h:69
TensorCwiseNullaryOp< NullaryOp, XprType > Nested
Definition: TensorExpr.h:60
Eigen::internal::traits< TensorCwiseNullaryOp >::Scalar Scalar
Definition: TensorExpr.h:57
Eigen::NumTraits< Scalar >::Real RealScalar
Definition: TensorExpr.h:58
Eigen::internal::traits< TensorCwiseNullaryOp >::StorageKind StorageKind
Definition: TensorExpr.h:61
XprType::CoeffReturnType CoeffReturnType
Definition: TensorExpr.h:59
TensorCwiseNullaryOp(const XprType &xpr, const NullaryOp &func=NullaryOp())
Definition: TensorExpr.h:64
const NullaryOp m_functor
Definition: TensorExpr.h:76
Eigen::internal::traits< TensorCwiseNullaryOp >::Index Index
Definition: TensorExpr.h:62
XprType::Nested m_xpr
Definition: TensorExpr.h:75
Eigen::internal::traits< TensorCwiseTernaryOp >::StorageKind StorageKind
Definition: TensorExpr.h:287
Arg3XprType::Nested m_arg3_xpr
Definition: TensorExpr.h:312
const internal::remove_all_t< typename Arg3XprType::Nested > & arg3Expression() const
Definition: TensorExpr.h:307
TensorCwiseTernaryOp(const Arg1XprType &arg1, const Arg2XprType &arg2, const Arg3XprType &arg3, const TernaryOp &func=TernaryOp())
Definition: TensorExpr.h:290
const TernaryOp & functor() const
Definition: TensorExpr.h:294
Eigen::NumTraits< Scalar >::Real RealScalar
Definition: TensorExpr.h:284
const internal::remove_all_t< typename Arg2XprType::Nested > & arg2Expression() const
Definition: TensorExpr.h:303
const internal::remove_all_t< typename Arg1XprType::Nested > & arg1Expression() const
Definition: TensorExpr.h:299
Arg2XprType::Nested m_arg2_xpr
Definition: TensorExpr.h:311
Eigen::internal::traits< TensorCwiseTernaryOp >::Index Index
Definition: TensorExpr.h:288
Eigen::internal::traits< TensorCwiseTernaryOp >::Scalar Scalar
Definition: TensorExpr.h:283
const TernaryOp m_functor
Definition: TensorExpr.h:313
Arg1XprType::Nested m_arg1_xpr
Definition: TensorExpr.h:310
Eigen::internal::nested< TensorCwiseTernaryOp >::type Nested
Definition: TensorExpr.h:286
Eigen::internal::traits< TensorCwiseUnaryOp >::StorageKind StorageKind
Definition: TensorExpr.h:126
XprType::Nested m_xpr
Definition: TensorExpr.h:141
const UnaryOp m_functor
Definition: TensorExpr.h:142
TensorCwiseUnaryOp(const XprType &xpr, const UnaryOp &func=UnaryOp())
Definition: TensorExpr.h:129
Eigen::internal::traits< TensorCwiseUnaryOp >::Scalar Scalar
Definition: TensorExpr.h:122
const internal::remove_all_t< typename XprType::Nested > & nestedExpression() const
Definition: TensorExpr.h:138
Eigen::internal::traits< TensorCwiseUnaryOp >::Index Index
Definition: TensorExpr.h:127
Eigen::internal::nested< TensorCwiseUnaryOp >::type Nested
Definition: TensorExpr.h:125
Eigen::NumTraits< Scalar >::Real RealScalar
Definition: TensorExpr.h:123
const UnaryOp & functor() const
Definition: TensorExpr.h:133
const IfXprType & ifExpression() const
Definition: TensorExpr.h:373
ThenXprType::Nested m_then
Definition: TensorExpr.h:383
const ThenXprType & thenExpression() const
Definition: TensorExpr.h:376
Eigen::internal::traits< TensorSelectOp >::Index Index
Definition: TensorExpr.h:363
Eigen::NumTraits< Scalar >::Real RealScalar
Definition: TensorExpr.h:358
const ElseXprType & elseExpression() const
Definition: TensorExpr.h:379
Eigen::internal::nested< TensorSelectOp >::type Nested
Definition: TensorExpr.h:361
Eigen::internal::traits< TensorSelectOp >::StorageKind StorageKind
Definition: TensorExpr.h:362
TensorSelectOp(const IfXprType &a_condition, const ThenXprType &a_then, const ElseXprType &a_else)
Definition: TensorExpr.h:366
Eigen::internal::traits< TensorSelectOp >::Scalar Scalar
Definition: TensorExpr.h:357
IfXprType::Nested m_condition
Definition: TensorExpr.h:382
ElseXprType::Nested m_else
Definition: TensorExpr.h:384
internal::promote_storage_type< typename ThenXprType::CoeffReturnType, typename ElseXprType::CoeffReturnType >::ret CoeffReturnType
Definition: TensorExpr.h:360
typename remove_all< T >::type remove_all_t
TensorContractionSycl.h: provides various tensor contraction kernels for the SYCL backend.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index