Symmetry.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2013 Christian Seiler <christian@iwakd.de>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSORSYMMETRY_SYMMETRY_H
11 #define EIGEN_CXX11_TENSORSYMMETRY_SYMMETRY_H
12 
#include <utility>

#include "./InternalHeaderCheck.h"
14 
15 namespace Eigen {
16 
/** Bit flags attached to a symmetry generator, describing how a coefficient
  * is transformed when it is copied to the symmetry-related index tuple. */
enum {
  /** The copied coefficient is negated. */
  NegationFlag = 1 << 0,
  /** The copied coefficient is complex-conjugated. */
  ConjugationFlag = 1 << 1
};
21 
/** Constraints a symmetry group can impose on the value of a single
  * coefficient (e.g. a diagonal entry mapped onto itself). */
enum {
  /** The coefficient must be real (its imaginary part must be zero). */
  GlobalRealFlag = 0x01,
  /** The coefficient must be purely imaginary (its real part must be zero). */
  GlobalImagFlag = 0x02,
  /** The coefficient must be zero: both real and imaginary,
    * i.e. GlobalRealFlag | GlobalImagFlag. */
  GlobalZeroFlag = 0x03
};
27 
namespace internal {

// Forward declarations of the metaprogramming helpers defined further below.
template<typename... Sym>
struct tensor_symmetry_num_indices;

template<std::size_t NumIndices, typename... Sym>
struct tensor_symmetry_pre_analysis;

template<std::size_t NumIndices, typename... Sym>
struct tensor_static_symgroup;

template<bool instantiate, std::size_t NumIndices, typename... Sym>
struct tensor_static_symgroup_if;

template<typename Tensor_>
struct tensor_symmetry_assign_value;

template<typename Tensor_>
struct tensor_symmetry_calculate_flags;

} // end namespace internal
38 
/** \brief Symmetry under exchange of two tensor indices.
  *
  * Swapping index \c One_ with index \c Two_ leaves the coefficient as it
  * is: neither negation nor conjugation is applied (Flags == 0).
  *
  * \tparam One_ first index of the pair (zero-based)
  * \tparam Two_ second index of the pair; must differ from \c One_
  */
template<int One_, int Two_>
struct Symmetry
{
  static_assert(One_ != Two_, "Symmetries must cover distinct indices.");
  static constexpr int One = One_;
  static constexpr int Two = Two_;
  static constexpr int Flags = 0;
};
47 
48 template<int One_, int Two_>
50 {
51  static_assert(One_ != Two_, "Symmetries must cover distinct indices.");
52  constexpr static int One = One_;
53  constexpr static int Two = Two_;
54  constexpr static int Flags = NegationFlag;
55 };
56 
57 template<int One_, int Two_>
59 {
60  static_assert(One_ != Two_, "Symmetries must cover distinct indices.");
61  constexpr static int One = One_;
62  constexpr static int Two = Two_;
63  constexpr static int Flags = ConjugationFlag;
64 };
65 
66 template<int One_, int Two_>
68 {
69  static_assert(One_ != Two_, "Symmetries must cover distinct indices.");
70  constexpr static int One = One_;
71  constexpr static int Two = Two_;
72  constexpr static int Flags = ConjugationFlag | NegationFlag;
73 };
74 
/** \brief Dynamic symmetry group (forward declaration). */
class DynamicSGroup;

/** \brief Dynamic symmetry group, initialized from template arguments
  * (forward declaration).
  *
  * Used by internal::tensor_symmetry_pre_analysis when a static group is not
  * suitable for the given generators. */
template<typename... Gen>
class DynamicSGroupFromTemplateArgs;

/** \brief Static symmetry group (forward declaration). */
template<typename... Gen>
class StaticSGroup;
137 template<typename... Gen>
138 class SGroup : public internal::tensor_symmetry_pre_analysis<internal::tensor_symmetry_num_indices<Gen...>::value, Gen...>::root_type
139 {
140  public:
141  constexpr static std::size_t NumIndices = internal::tensor_symmetry_num_indices<Gen...>::value;
142  typedef typename internal::tensor_symmetry_pre_analysis<NumIndices, Gen...>::root_type Base;
143 
144  // make standard constructors + assignment operators public
145  inline SGroup() : Base() { }
146  inline SGroup(const SGroup<Gen...>& other) : Base(other) { }
147  inline SGroup(SGroup<Gen...>&& other) : Base(other) { }
148  inline SGroup<Gen...>& operator=(const SGroup<Gen...>& other) { Base::operator=(other); return *this; }
149  inline SGroup<Gen...>& operator=(SGroup<Gen...>&& other) { Base::operator=(other); return *this; }
150 
151  // all else is defined in the base class
152 };
153 
154 namespace internal {
155 
/** \internal Number of tensor indices spanned by a list of symmetry
  * generators.
  *
  * This primary template covers the empty list and any list whose head is
  * not one of the specialized generator types; it yields 1. */
template<typename... Sym>
struct tensor_symmetry_num_indices
{
  static constexpr std::size_t value = 1;
};
160 
161 template<int One_, int Two_, typename... Sym> struct tensor_symmetry_num_indices<Symmetry<One_, Two_>, Sym...>
162 {
163 private:
164  constexpr static std::size_t One = static_cast<std::size_t>(One_);
165  constexpr static std::size_t Two = static_cast<std::size_t>(Two_);
166  constexpr static std::size_t Three = tensor_symmetry_num_indices<Sym...>::value;
167 
168  // don't use std::max, since it's not constexpr until C++14...
169  constexpr static std::size_t maxOneTwoPlusOne = ((One > Two) ? One : Two) + 1;
170 public:
171  constexpr static std::size_t value = (maxOneTwoPlusOne > Three) ? maxOneTwoPlusOne : Three;
172 };
173 
174 template<int One_, int Two_, typename... Sym> struct tensor_symmetry_num_indices<AntiSymmetry<One_, Two_>, Sym...>
175  : public tensor_symmetry_num_indices<Symmetry<One_, Two_>, Sym...> {};
176 template<int One_, int Two_, typename... Sym> struct tensor_symmetry_num_indices<Hermiticity<One_, Two_>, Sym...>
177  : public tensor_symmetry_num_indices<Symmetry<One_, Two_>, Sym...> {};
178 template<int One_, int Two_, typename... Sym> struct tensor_symmetry_num_indices<AntiHermiticity<One_, Two_>, Sym...>
179  : public tensor_symmetry_num_indices<Symmetry<One_, Two_>, Sym...> {};
180 
228 template<std::size_t NumIndices>
229 struct tensor_symmetry_pre_analysis<NumIndices>
230 {
231  typedef StaticSGroup<> root_type;
232 };
233 
234 template<std::size_t NumIndices, typename Gen_, typename... Gens_>
235 struct tensor_symmetry_pre_analysis<NumIndices, Gen_, Gens_...>
236 {
237  constexpr static std::size_t max_static_generators = 4;
238  constexpr static std::size_t max_static_elements = 16;
239  typedef tensor_static_symgroup_if<(sizeof...(Gens_) + 1 <= max_static_generators), NumIndices, Gen_, Gens_...> helper;
240  constexpr static std::size_t possible_size = helper::size;
241 
242  typedef std::conditional_t<
243  possible_size == 0 || possible_size >= max_static_elements,
244  DynamicSGroupFromTemplateArgs<Gen_, Gens_...>,
245  typename helper::type
246  > root_type;
247 };
248 
/** \internal Disabled case: no static group is built.
  *
  * \c size is 0, which tensor_symmetry_pre_analysis treats as "use the
  * dynamic group"; \c type is void. */
template<bool instantiate, std::size_t NumIndices, typename... Gens>
struct tensor_static_symgroup_if
{
  using type = void;
  static constexpr std::size_t size = 0;
};
255 
256 template<std::size_t NumIndices, typename... Gens>
257 struct tensor_static_symgroup_if<true, NumIndices, Gens...> : tensor_static_symgroup<NumIndices, Gens...> {};
258 
259 template<typename Tensor_>
260 struct tensor_symmetry_assign_value
261 {
262  typedef typename Tensor_::Index Index;
263  typedef typename Tensor_::Scalar Scalar;
264  constexpr static std::size_t NumIndices = Tensor_::NumIndices;
265 
266  static inline int run(const std::array<Index, NumIndices>& transformed_indices, int transformation_flags, int dummy, Tensor_& tensor, const Scalar& value_)
267  {
268  Scalar value(value_);
269  if (transformation_flags & ConjugationFlag)
270  value = numext::conj(value);
271  if (transformation_flags & NegationFlag)
272  value = -value;
273  tensor.coeffRef(transformed_indices) = value;
274  return dummy;
275  }
276 };
277 
278 template<typename Tensor_>
279 struct tensor_symmetry_calculate_flags
280 {
281  typedef typename Tensor_::Index Index;
282  constexpr static std::size_t NumIndices = Tensor_::NumIndices;
283 
284  static inline int run(const std::array<Index, NumIndices>& transformed_indices, int transform_flags, int current_flags, const std::array<Index, NumIndices>& orig_indices)
285  {
286  if (transformed_indices == orig_indices) {
287  if (transform_flags & (ConjugationFlag | NegationFlag))
288  return current_flags | GlobalImagFlag; // anti-hermitian diagonal
289  else if (transform_flags & ConjugationFlag)
290  return current_flags | GlobalRealFlag; // hermitian diagonal
291  else if (transform_flags & NegationFlag)
292  return current_flags | GlobalZeroFlag; // anti-symmetric diagonal
293  }
294  return current_flags;
295  }
296 };
297 
298 template<typename Tensor_, typename Symmetry_, int Flags = 0>
299 class tensor_symmetry_value_setter
300 {
301  public:
302  typedef typename Tensor_::Index Index;
303  typedef typename Tensor_::Scalar Scalar;
304  constexpr static std::size_t NumIndices = Tensor_::NumIndices;
305 
306  inline tensor_symmetry_value_setter(Tensor_& tensor, Symmetry_ const& symmetry, std::array<Index, NumIndices> const& indices)
307  : m_tensor(tensor), m_symmetry(symmetry), m_indices(indices) { }
308 
309  inline tensor_symmetry_value_setter<Tensor_, Symmetry_, Flags>& operator=(Scalar const& value)
310  {
311  doAssign(value);
312  return *this;
313  }
314  private:
315  Tensor_& m_tensor;
316  Symmetry_ m_symmetry;
317  std::array<Index, NumIndices> m_indices;
318 
319  inline void doAssign(Scalar const& value)
320  {
321  #ifdef EIGEN_TENSOR_SYMMETRY_CHECK_VALUES
322  int value_flags = m_symmetry.template apply<internal::tensor_symmetry_calculate_flags<Tensor_>, int>(m_indices, m_symmetry.globalFlags(), m_indices);
323  if (value_flags & GlobalRealFlag)
324  eigen_assert(numext::imag(value) == 0);
325  if (value_flags & GlobalImagFlag)
326  eigen_assert(numext::real(value) == 0);
327  #endif
328  m_symmetry.template apply<internal::tensor_symmetry_assign_value<Tensor_>, int>(m_indices, 0, m_tensor, value);
329  }
330 };
331 
332 } // end namespace internal
333 
334 } // end namespace Eigen
335 
336 #endif // EIGEN_CXX11_TENSORSYMMETRY_SYMMETRY_H
337 
338 /*
339  * kate: space-indent on; indent-width 2; mixedindent off; indent-mode cstyle;
340  */
#define eigen_assert(x)
Dynamic symmetry group.
Symmetry group, initialized from template arguments.
Definition: Symmetry.h:139
SGroup< Gen... > & operator=(const SGroup< Gen... > &other)
Definition: Symmetry.h:148
SGroup< Gen... > & operator=(SGroup< Gen... > &&other)
Definition: Symmetry.h:149
constexpr static std::size_t NumIndices
Definition: Symmetry.h:141
internal::tensor_symmetry_pre_analysis< NumIndices, Gen... >::root_type Base
Definition: Symmetry.h:142
SGroup(SGroup< Gen... > &&other)
Definition: Symmetry.h:147
SGroup(const SGroup< Gen... > &other)
Definition: Symmetry.h:146
Static symmetry group.
TensorContractionSycl.h: provides various tensor contraction kernels for the SYCL backend
@ NegationFlag
Definition: Symmetry.h:18
@ ConjugationFlag
Definition: Symmetry.h:19
@ GlobalZeroFlag
Definition: Symmetry.h:25
@ GlobalRealFlag
Definition: Symmetry.h:23
@ GlobalImagFlag
Definition: Symmetry.h:24
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
const adouble & real(const adouble &x)
Definition: AdolcForward:72
adouble imag(const adouble &)
Definition: AdolcForward:73
constexpr static int Flags
Definition: Symmetry.h:72
constexpr static int One
Definition: Symmetry.h:70
constexpr static int Two
Definition: Symmetry.h:71
constexpr static int Flags
Definition: Symmetry.h:54
constexpr static int Two
Definition: Symmetry.h:53
constexpr static int One
Definition: Symmetry.h:52
constexpr static int Two
Definition: Symmetry.h:62
constexpr static int One
Definition: Symmetry.h:61
constexpr static int Flags
Definition: Symmetry.h:63
SparseMat::Index size
constexpr static int Two
Definition: Symmetry.h:44
constexpr static int Flags
Definition: Symmetry.h:45
constexpr static int One
Definition: Symmetry.h:43