TensorIO.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_IO_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_IO_H
12 
13 #include "./InternalHeaderCheck.h"
14 
15 namespace Eigen {
16 
17 struct TensorIOFormat;
18 
19 namespace internal {
20 template <typename Tensor, std::size_t rank>
21 struct TensorPrinter;
22 }
23 
25  TensorIOFormat(const std::vector<std::string>& _separator, const std::vector<std::string>& _prefix,
26  const std::vector<std::string>& _suffix, int _precision = StreamPrecision, int _flags = 0,
27  const std::string& _tenPrefix = "", const std::string& _tenSuffix = "", const char _fill = ' ')
28  : tenPrefix(_tenPrefix),
29  tenSuffix(_tenSuffix),
30  prefix(_prefix),
31  suffix(_suffix),
32  separator(_separator),
33  fill(_fill),
34  precision(_precision),
35  flags(_flags) {
36  init_spacer();
37  }
38 
39  TensorIOFormat(int _precision = StreamPrecision, int _flags = 0, const std::string& _tenPrefix = "",
40  const std::string& _tenSuffix = "", const char _fill = ' ')
41  : tenPrefix(_tenPrefix), tenSuffix(_tenSuffix), fill(_fill), precision(_precision), flags(_flags) {
42  // default values of prefix, suffix and separator
43  prefix = {"", "["};
44  suffix = {"", "]"};
45  separator = {", ", "\n"};
46 
47  init_spacer();
48  }
49 
50  void init_spacer() {
51  if ((flags & DontAlignCols)) return;
52  spacer.resize(prefix.size());
53  spacer[0] = "";
54  int i = int(tenPrefix.length()) - 1;
55  while (i >= 0 && tenPrefix[i] != '\n') {
56  spacer[0] += ' ';
57  i--;
58  }
59 
60  for (std::size_t k = 1; k < prefix.size(); k++) {
61  int j = int(prefix[k].length()) - 1;
62  while (j >= 0 && prefix[k][j] != '\n') {
63  spacer[k] += ' ';
64  j--;
65  }
66  }
67  }
68 
69  static inline const TensorIOFormat Numpy() {
70  std::vector<std::string> prefix = {"", "["};
71  std::vector<std::string> suffix = {"", "]"};
72  std::vector<std::string> separator = {" ", "\n"};
73  return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "[", "]");
74  }
75 
76  static inline const TensorIOFormat Plain() {
77  std::vector<std::string> separator = {" ", "\n", "\n", ""};
78  std::vector<std::string> prefix = {""};
79  std::vector<std::string> suffix = {""};
80  return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "", "", ' ');
81  }
82 
83  static inline const TensorIOFormat Native() {
84  std::vector<std::string> separator = {", ", ",\n", "\n"};
85  std::vector<std::string> prefix = {"", "{"};
86  std::vector<std::string> suffix = {"", "}"};
87  return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "{", "}", ' ');
88  }
89 
90  static inline const TensorIOFormat Legacy() {
91  TensorIOFormat LegacyFormat(StreamPrecision, 0, "", "", ' ');
92  LegacyFormat.legacy_bit = true;
93  return LegacyFormat;
94  }
95 
96  std::string tenPrefix;
97  std::string tenSuffix;
98  std::vector<std::string> prefix;
99  std::vector<std::string> suffix;
100  std::vector<std::string> separator;
101  char fill;
103  int flags;
104  std::vector<std::string> spacer{};
105  bool legacy_bit = false;
106 };
107 
108 template <typename T, int Layout, int rank>
110 // specialize for Layout=ColMajor, Layout=RowMajor and rank=0.
111 template <typename T, int rank>
112 class TensorWithFormat<T, RowMajor, rank> {
113  public:
114  TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {}
115 
116  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, RowMajor, rank>& wf) {
117  // Evaluate the expression if needed
119  TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
120  Evaluator tensor(eval, DefaultDevice());
121  tensor.evalSubExprsIfNeeded(NULL);
122  internal::TensorPrinter<Evaluator, rank>::run(os, tensor, wf.t_format);
123  // Cleanup.
124  tensor.cleanup();
125  return os;
126  }
127 
128  protected:
131 };
132 
133 template <typename T, int rank>
134 class TensorWithFormat<T, ColMajor, rank> {
135  public:
136  TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {}
137 
138  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, rank>& wf) {
139  // Switch to RowMajor storage and print afterwards
140  typedef typename T::Index IndexType;
141  std::array<IndexType, rank> shuffle;
142  std::array<IndexType, rank> id;
143  std::iota(id.begin(), id.end(), IndexType(0));
144  std::copy(id.begin(), id.end(), shuffle.rbegin());
145  auto tensor_row_major = wf.t_tensor.swap_layout().shuffle(shuffle);
146 
147  // Evaluate the expression if needed
148  typedef TensorEvaluator<const TensorForcedEvalOp<const decltype(tensor_row_major)>, DefaultDevice> Evaluator;
149  TensorForcedEvalOp<const decltype(tensor_row_major)> eval = tensor_row_major.eval();
150  Evaluator tensor(eval, DefaultDevice());
151  tensor.evalSubExprsIfNeeded(NULL);
152  internal::TensorPrinter<Evaluator, rank>::run(os, tensor, wf.t_format);
153  // Cleanup.
154  tensor.cleanup();
155  return os;
156  }
157 
158  protected:
161 };
162 
163 template <typename T>
165  public:
166  TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {}
167 
168  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, 0>& wf) {
169  // Evaluate the expression if needed
171  TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
172  Evaluator tensor(eval, DefaultDevice());
173  tensor.evalSubExprsIfNeeded(NULL);
174  internal::TensorPrinter<Evaluator, 0>::run(os, tensor, wf.t_format);
175  // Cleanup.
176  tensor.cleanup();
177  return os;
178  }
179 
180  protected:
183 };
184 
185 namespace internal {
186 template <typename Tensor, std::size_t rank>
187 struct TensorPrinter {
188  static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) {
189  typedef std::remove_const_t<typename Tensor::Scalar> Scalar;
190  typedef typename Tensor::Index IndexType;
191  static const int layout = Tensor::Layout;
192  // backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x
193  // (dim(1)*dim(2)*...*dim(rank-1)).
194  if (fmt.legacy_bit) {
195  const IndexType total_size = internal::array_prod(_t.dimensions());
196  if (total_size > 0) {
197  const IndexType first_dim = Eigen::internal::array_get<0>(_t.dimensions());
199  total_size / first_dim);
200  s << matrix;
201  return;
202  }
203  }
204 
205  eigen_assert(layout == RowMajor);
206  typedef std::conditional_t<is_same<Scalar, char>::value || is_same<Scalar, unsigned char>::value ||
207  is_same<Scalar, numext::int8_t>::value || is_same<Scalar, numext::uint8_t>::value,
208  int,
209  std::conditional_t<is_same<Scalar, std::complex<char> >::value ||
210  is_same<Scalar, std::complex<unsigned char> >::value ||
211  is_same<Scalar, std::complex<numext::int8_t> >::value ||
212  is_same<Scalar, std::complex<numext::uint8_t> >::value,
213  std::complex<int>, const Scalar&>> PrintType;
214 
215  const IndexType total_size = array_prod(_t.dimensions());
216 
217  std::streamsize explicit_precision;
218  if (fmt.precision == StreamPrecision) {
219  explicit_precision = 0;
220  } else if (fmt.precision == FullPrecision) {
222  explicit_precision = 0;
223  } else {
224  explicit_precision = significant_decimals_impl<Scalar>::run();
225  }
226  } else {
227  explicit_precision = fmt.precision;
228  }
229 
230  std::streamsize old_precision = 0;
231  if (explicit_precision) old_precision = s.precision(explicit_precision);
232 
233  IndexType width = 0;
234 
235  bool align_cols = !(fmt.flags & DontAlignCols);
236  if (align_cols) {
237  // compute the largest width
238  for (IndexType i = 0; i < total_size; i++) {
239  std::stringstream sstr;
240  sstr.copyfmt(s);
241  sstr << static_cast<PrintType>(_t.data()[i]);
242  width = std::max<IndexType>(width, IndexType(sstr.str().length()));
243  }
244  }
245  std::streamsize old_width = s.width();
246  char old_fill_character = s.fill();
247 
248  s << fmt.tenPrefix;
249  for (IndexType i = 0; i < total_size; i++) {
250  std::array<bool, rank> is_at_end{};
251  std::array<bool, rank> is_at_begin{};
252 
253  // is the ith element the end of an coeff (always true), of a row, of a matrix, ...?
254  for (std::size_t k = 0; k < rank; k++) {
255  if ((i + 1) % (std::accumulate(_t.dimensions().rbegin(), _t.dimensions().rbegin() + k, 1,
256  std::multiplies<IndexType>())) ==
257  0) {
258  is_at_end[k] = true;
259  }
260  }
261 
262  // is the ith element the begin of an coeff (always true), of a row, of a matrix, ...?
263  for (std::size_t k = 0; k < rank; k++) {
264  if (i % (std::accumulate(_t.dimensions().rbegin(), _t.dimensions().rbegin() + k, 1,
265  std::multiplies<IndexType>())) ==
266  0) {
267  is_at_begin[k] = true;
268  }
269  }
270 
271  // do we have a line break?
272  bool is_at_begin_after_newline = false;
273  for (std::size_t k = 0; k < rank; k++) {
274  if (is_at_begin[k]) {
275  std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
276  if (fmt.separator[separator_index].find('\n') != std::string::npos) {
277  is_at_begin_after_newline = true;
278  }
279  }
280  }
281 
282  bool is_at_end_before_newline = false;
283  for (std::size_t k = 0; k < rank; k++) {
284  if (is_at_end[k]) {
285  std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
286  if (fmt.separator[separator_index].find('\n') != std::string::npos) {
287  is_at_end_before_newline = true;
288  }
289  }
290  }
291 
292  std::stringstream suffix, prefix, separator;
293  for (std::size_t k = 0; k < rank; k++) {
294  std::size_t suffix_index = (k < fmt.suffix.size()) ? k : fmt.suffix.size() - 1;
295  if (is_at_end[k]) {
296  suffix << fmt.suffix[suffix_index];
297  }
298  }
299  for (std::size_t k = 0; k < rank; k++) {
300  std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
301  if (is_at_end[k] &&
302  (!is_at_end_before_newline || fmt.separator[separator_index].find('\n') != std::string::npos)) {
303  separator << fmt.separator[separator_index];
304  }
305  }
306  for (std::size_t k = 0; k < rank; k++) {
307  std::size_t spacer_index = (k < fmt.spacer.size()) ? k : fmt.spacer.size() - 1;
308  if (i != 0 && is_at_begin_after_newline && (!is_at_begin[k] || k == 0)) {
309  prefix << fmt.spacer[spacer_index];
310  }
311  }
312  for (int k = rank - 1; k >= 0; k--) {
313  std::size_t prefix_index = (static_cast<std::size_t>(k) < fmt.prefix.size()) ? k : fmt.prefix.size() - 1;
314  if (is_at_begin[k]) {
315  prefix << fmt.prefix[prefix_index];
316  }
317  }
318 
319  s << prefix.str();
320  if (width) {
321  s.fill(fmt.fill);
322  s.width(width);
323  s << std::right;
324  }
325  s << _t.data()[i];
326  s << suffix.str();
327  if (i < total_size - 1) {
328  s << separator.str();
329  }
330  }
331  s << fmt.tenSuffix;
332  if (explicit_precision) s.precision(old_precision);
333  if (width) {
334  s.fill(old_fill_character);
335  s.width(old_width);
336  }
337  }
338 };
339 
340 template <typename Tensor>
341 struct TensorPrinter<Tensor, 0> {
342  static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) {
343  typedef typename Tensor::Scalar Scalar;
344 
345  std::streamsize explicit_precision;
346  if (fmt.precision == StreamPrecision) {
347  explicit_precision = 0;
348  } else if (fmt.precision == FullPrecision) {
350  explicit_precision = 0;
351  } else {
352  explicit_precision = significant_decimals_impl<Scalar>::run();
353  }
354  } else {
355  explicit_precision = fmt.precision;
356  }
357 
358  std::streamsize old_precision = 0;
359  if (explicit_precision) old_precision = s.precision(explicit_precision);
360 
361  s << fmt.tenPrefix << _t.coeff(0) << fmt.tenSuffix;
362  if (explicit_precision) s.precision(old_precision);
363  }
364 };
365 
366 } // end namespace internal
367 template <typename T>
368 std::ostream& operator<<(std::ostream& s, const TensorBase<T, ReadOnlyAccessors>& t) {
369  s << t.format(TensorIOFormat::Plain());
370  return s;
371 }
372 } // end namespace Eigen
373 
374 #endif // EIGEN_CXX11_TENSOR_TENSOR_IO_H
int i
#define eigen_assert(x)
The tensor base class.
friend std::ostream & operator<<(std::ostream &os, const TensorWithFormat< T, ColMajor, 0 > &wf)
Definition: TensorIO.h:168
TensorWithFormat(const T &tensor, const TensorIOFormat &format)
Definition: TensorIO.h:166
TensorWithFormat(const T &tensor, const TensorIOFormat &format)
Definition: TensorIO.h:136
friend std::ostream & operator<<(std::ostream &os, const TensorWithFormat< T, ColMajor, rank > &wf)
Definition: TensorIO.h:138
friend std::ostream & operator<<(std::ostream &os, const TensorWithFormat< T, RowMajor, rank > &wf)
Definition: TensorIO.h:116
TensorWithFormat(const T &tensor, const TensorIOFormat &format)
Definition: TensorIO.h:114
The tensor class.
Definition: Tensor.h:67
static constexpr int Layout
Definition: Tensor.h:84
Scalar_ Scalar
Definition: Tensor.h:74
internal::traits< Self >::Index Index
Definition: Tensor.h:73
const Dimensions & dimensions() const
Definition: Tensor.h:103
Scalar * data()
Definition: Tensor.h:105
static const lastp1_t end
constexpr auto array_prod(const array< T, N > &arr) -> decltype(array_reduce< product_op, T, N >(arr, static_cast< T >(1)))
: TensorContractionSycl.h, provides various tensor contraction kernels for the SYCL backend
std::ostream & operator<<(std::ostream &s, const DiagonalBase< Derived > &m)
StreamPrecision
A cost model used to limit the number of threads used for evaluating tensor expressions.
std::vector< std::string > separator
Definition: TensorIO.h:100
std::vector< std::string > prefix
Definition: TensorIO.h:98
std::vector< std::string > spacer
Definition: TensorIO.h:104
static const TensorIOFormat Numpy()
Definition: TensorIO.h:69
static const TensorIOFormat Legacy()
Definition: TensorIO.h:90
static const TensorIOFormat Plain()
Definition: TensorIO.h:76
std::vector< std::string > suffix
Definition: TensorIO.h:99
TensorIOFormat(int _precision=StreamPrecision, int _flags=0, const std::string &_tenPrefix="", const std::string &_tenSuffix="", const char _fill=' ')
Definition: TensorIO.h:39
std::string tenPrefix
Definition: TensorIO.h:96
std::string tenSuffix
Definition: TensorIO.h:97
static const TensorIOFormat Native()
Definition: TensorIO.h:83
TensorIOFormat(const std::vector< std::string > &_separator, const std::vector< std::string > &_prefix, const std::vector< std::string > &_suffix, int _precision=StreamPrecision, int _flags=0, const std::string &_tenPrefix="", const std::string &_tenSuffix="", const char _fill=' ')
Definition: TensorIO.h:25
std::ptrdiff_t j