#ifndef EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
#define EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
27 template<
typename Dims,
typename XprType>
28 struct traits<TensorTraceOp<Dims, XprType> > :
public traits<XprType>
30 typedef typename XprType::Scalar Scalar;
32 typedef typename XprTraits::StorageKind StorageKind;
33 typedef typename XprTraits::Index
Index;
34 typedef typename XprType::Nested Nested;
35 typedef std::remove_reference_t<Nested> Nested_;
36 static constexpr
int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
37 static constexpr
int Layout = XprTraits::Layout;
40 template<
typename Dims,
typename XprType>
41 struct eval<TensorTraceOp<Dims, XprType>,
Eigen::Dense>
43 typedef const TensorTraceOp<Dims, XprType>& type;
46 template<
typename Dims,
typename XprType>
47 struct nested<TensorTraceOp<Dims, XprType>, 1, typename eval<TensorTraceOp<Dims, XprType> >::type>
49 typedef TensorTraceOp<Dims, XprType> type;
55 template<
typename Dims,
typename XprType>
59 typedef typename Eigen::internal::traits<TensorTraceOp>::Scalar
Scalar;
62 typedef typename Eigen::internal::nested<TensorTraceOp>::type
Nested;
63 typedef typename Eigen::internal::traits<TensorTraceOp>::StorageKind
StorageKind;
64 typedef typename Eigen::internal::traits<TensorTraceOp>::Index
Index;
83 template<
typename Dims,
typename ArgType,
typename Device>
87 static constexpr
int NumInputDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
88 static constexpr
int NumReducedDims = internal::array_size<Dims>::value;
89 static constexpr
int NumOutputDims = NumInputDims - NumReducedDims;
95 static constexpr
int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
114 : m_impl(op.expression(), device), m_traceDim(1),
m_device(device)
118 EIGEN_STATIC_ASSERT((NumReducedDims >= 2) || ((NumReducedDims == 0) && (NumInputDims == 0)), YOU_MADE_A_PROGRAMMING_MISTAKE);
120 for (
int i = 0;
i < NumInputDims; ++
i) {
121 m_reduced[
i] =
false;
124 const Dims& op_dims = op.
dims();
125 for (
int i = 0;
i < NumReducedDims; ++
i) {
128 m_reduced[op_dims[
i]] =
true;
132 int num_distinct_reduce_dims = 0;
133 for (
int i = 0;
i < NumInputDims; ++
i) {
135 ++num_distinct_reduce_dims;
140 eigen_assert(num_distinct_reduce_dims == NumReducedDims);
145 int output_index = 0;
146 int reduced_index = 0;
147 for (
int i = 0;
i < NumInputDims; ++
i) {
149 m_reducedDims[reduced_index] = input_dims[
i];
150 if (reduced_index > 0) {
152 eigen_assert(m_reducedDims[0] == m_reducedDims[reduced_index]);
157 m_dimensions[output_index] = input_dims[
i];
162 if (NumReducedDims != 0) {
163 m_traceDim = m_reducedDims[0];
167 if (NumOutputDims > 0) {
169 m_outputStrides[0] = 1;
170 for (
int i = 1;
i < NumOutputDims; ++
i) {
171 m_outputStrides[
i] = m_outputStrides[
i - 1] * m_dimensions[
i - 1];
175 m_outputStrides.back() = 1;
176 for (
int i = NumOutputDims - 2;
i >= 0; --
i) {
177 m_outputStrides[
i] = m_outputStrides[
i + 1] * m_dimensions[
i + 1];
183 if (NumInputDims > 0) {
186 input_strides[0] = 1;
187 for (
int i = 1;
i < NumInputDims; ++
i) {
188 input_strides[
i] = input_strides[
i - 1] * input_dims[
i - 1];
192 input_strides.back() = 1;
193 for (
int i = NumInputDims - 2;
i >= 0; --
i) {
194 input_strides[
i] = input_strides[
i + 1] * input_dims[
i + 1];
200 for (
int i = 0;
i < NumInputDims; ++
i) {
202 m_reducedStrides[reduced_index] = input_strides[
i];
206 m_preservedStrides[output_index] = input_strides[
i];
218 m_impl.evalSubExprsIfNeeded(NULL);
230 Index index_stride = 0;
231 for (
int i = 0;
i < NumReducedDims; ++
i) {
232 index_stride += m_reducedStrides[
i];
237 if (NumOutputDims != 0)
238 cur_index = firstInput(index);
239 for (
Index i = 0;
i < m_traceDim; ++
i) {
240 result += m_impl.coeff(cur_index);
241 cur_index += index_stride;
247 template<
int LoadMode>
255 PacketReturnType result = internal::ploadt<PacketReturnType, LoadMode>(values);
262 Index startInput = 0;
264 for (
int i = NumOutputDims - 1;
i > 0; --
i) {
265 const Index idx = index / m_outputStrides[
i];
266 startInput += idx * m_preservedStrides[
i];
267 index -= idx * m_outputStrides[
i];
269 startInput += index * m_preservedStrides[0];
272 for (
int i = 0;
i < NumOutputDims - 1; ++
i) {
273 const Index idx = index / m_outputStrides[
i];
274 startInput += idx * m_preservedStrides[
i];
275 index -= idx * m_outputStrides[
i];
277 startInput += index * m_preservedStrides[NumOutputDims - 1];
#define EIGEN_DEVICE_FUNC
#define EIGEN_ONLY_USED_FOR_DEBUG(x)
#define EIGEN_STATIC_ASSERT(X, MSG)
Eigen::internal::nested< TensorTraceOp >::type Nested
TensorTraceOp(const XprType &expr, const Dims &dims)
const Dims & dims() const
XprType::CoeffReturnType CoeffReturnType
const internal::remove_all_t< typename XprType::Nested > & expression() const
Eigen::internal::traits< TensorTraceOp >::StorageKind StorageKind
Eigen::internal::traits< TensorTraceOp >::Scalar Scalar
Eigen::NumTraits< Scalar >::Real RealScalar
Eigen::internal::traits< TensorTraceOp >::Index Index
typename remove_all< T >::type remove_all_t
: TensorContractionSycl.h, provides various tensor contraction kernel for SYCL backend
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
internal::packet_traits< Scalar >::type type
internal::TensorBlockNotImplemented TensorBlock
Index firstInput(Index index) const
PacketReturnType packet(Index index) const
const Dimensions & dimensions() const
array< Index, NumOutputDims > m_preservedStrides
bool evalSubExprsIfNeeded(EvaluatorPointerType)
array< Index, NumReducedDims > m_reducedDims
array< bool, NumInputDims > m_reduced
PacketType< CoeffReturnType, Device >::type PacketReturnType
Storage::Type EvaluatorPointerType
TensorTraceOp< Dims, ArgType > XprType
TensorEvaluator(const XprType &op, const Device &device)
array< Index, NumReducedDims > m_reducedStrides
const Device EIGEN_DEVICE_REF m_device
XprType::CoeffReturnType CoeffReturnType
array< Index, NumOutputDims > m_outputStrides
TensorEvaluator< ArgType, Device > m_impl
StorageMemory< CoeffReturnType, Device > Storage
DSizes< Index, NumOutputDims > Dimensions
CoeffReturnType coeff(Index index) const
A cost model used to limit the number of threads used for evaluating tensor expression.
const Dimensions & dimensions() const
static constexpr int Layout
const Device EIGEN_DEVICE_REF m_device
CoeffReturnType coeff(Index index) const
static constexpr int PacketSize