TensorDevice.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
12 
13 #include "./InternalHeaderCheck.h"
14 
15 namespace Eigen {
16 
29 template <typename ExpressionType, typename DeviceType> class TensorDevice {
30  public:
31  TensorDevice(const DeviceType& device, ExpressionType& expression) : m_device(device), m_expression(expression) {}
32 
34 
35  template<typename OtherDerived>
36  EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) {
38  Assign assign(m_expression, other);
39  internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
40  return *this;
41  }
42 
43  template<typename OtherDerived>
44  EIGEN_STRONG_INLINE TensorDevice& operator+=(const OtherDerived& other) {
45  typedef typename OtherDerived::Scalar Scalar;
46  typedef TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const ExpressionType, const OtherDerived> Sum;
47  Sum sum(m_expression, other);
49  Assign assign(m_expression, sum);
50  internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
51  return *this;
52  }
53 
54  template<typename OtherDerived>
55  EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
56  typedef typename OtherDerived::Scalar Scalar;
57  typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
58  Difference difference(m_expression, other);
60  Assign assign(m_expression, difference);
61  internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
62  return *this;
63  }
64 
65  protected:
66  const DeviceType& m_device;
67  ExpressionType& m_expression;
68 };
69 
84 template <typename ExpressionType, typename DeviceType, typename DoneCallback>
86  public:
87  TensorAsyncDevice(const DeviceType& device, ExpressionType& expression,
88  DoneCallback done)
89  : m_device(device), m_expression(expression), m_done(std::move(done)) {}
90 
91  template <typename OtherDerived>
92  EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) {
94  typedef internal::TensorExecutor<const Assign, DeviceType> Executor;
95 
96  Assign assign(m_expression, other);
97  Executor::run(assign, m_device);
98  m_done();
99 
100  return *this;
101  }
102 
103  protected:
104  const DeviceType& m_device;
105  ExpressionType& m_expression;
106  DoneCallback m_done;
107 };
108 
109 
110 #ifdef EIGEN_USE_THREADS
111 template <typename ExpressionType, typename DoneCallback>
112 class TensorAsyncDevice<ExpressionType, ThreadPoolDevice, DoneCallback> {
113  public:
114  TensorAsyncDevice(const ThreadPoolDevice& device, ExpressionType& expression,
115  DoneCallback done)
116  : m_device(device), m_expression(expression), m_done(std::move(done)) {}
117 
118  template <typename OtherDerived>
119  EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) {
120  typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
121  typedef internal::TensorAsyncExecutor<const Assign, ThreadPoolDevice, DoneCallback> Executor;
122 
123  // WARNING: After assignment 'm_done' callback will be in undefined state.
124  Assign assign(m_expression, other);
125  Executor::runAsync(assign, m_device, std::move(m_done));
126 
127  return *this;
128  }
129 
130  protected:
131  const ThreadPoolDevice& m_device;
132  ExpressionType& m_expression;
133  DoneCallback m_done;
134 };
135 #endif
136 
137 } // end namespace Eigen
138 
139 #endif // EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
#define EIGEN_DEFAULT_COPY_CONSTRUCTOR(CLASS)
Pseudo expression providing an operator = that will evaluate its argument asynchronously on the specified device.
Definition: TensorDevice.h:85
TensorAsyncDevice(const DeviceType &device, ExpressionType &expression, DoneCallback done)
Definition: TensorDevice.h:87
const DeviceType & m_device
Definition: TensorDevice.h:104
TensorAsyncDevice & operator=(const OtherDerived &other)
Definition: TensorDevice.h:92
ExpressionType & m_expression
Definition: TensorDevice.h:105
Pseudo expression providing an operator = that will evaluate its argument on the specified computing device.
Definition: TensorDevice.h:29
TensorDevice & operator-=(const OtherDerived &other)
Definition: TensorDevice.h:55
TensorDevice(const DeviceType &device, ExpressionType &expression)
Definition: TensorDevice.h:31
TensorDevice & operator+=(const OtherDerived &other)
Definition: TensorDevice.h:44
const DeviceType & m_device
Definition: TensorDevice.h:66
TensorDevice & operator=(const OtherDerived &other)
Definition: TensorDevice.h:36
ExpressionType & m_expression
Definition: TensorDevice.h:67
TensorContractionSycl.h: provides various tensor contraction kernels for the SYCL backend.