10 #ifndef EIGEN_CXX11_TENSORSYMMETRY_SYMMETRY_H
11 #define EIGEN_CXX11_TENSORSYMMETRY_SYMMETRY_H
30 template<std::size_t NumIndices,
typename... Sym>
struct tensor_symmetry_pre_analysis;
31 template<std::size_t NumIndices,
typename... Sym>
struct tensor_static_symgroup;
32 template<
bool instantiate, std::size_t NumIndices,
typename... Sym>
struct tensor_static_symgroup_if;
33 template<
typename Tensor_>
struct tensor_symmetry_calculate_flags;
34 template<
typename Tensor_>
struct tensor_symmetry_assign_value;
35 template<
typename... Sym>
struct tensor_symmetry_num_indices;
39 template<
int One_,
int Two_>
42 static_assert(One_ != Two_,
"Symmetries must cover distinct indices.");
43 constexpr
static int One = One_;
44 constexpr
static int Two = Two_;
48 template<
int One_,
int Two_>
51 static_assert(One_ != Two_,
"Symmetries must cover distinct indices.");
52 constexpr
static int One = One_;
53 constexpr
static int Two = Two_;
57 template<
int One_,
int Two_>
60 static_assert(One_ != Two_,
"Symmetries must cover distinct indices.");
61 constexpr
static int One = One_;
62 constexpr
static int Two = Two_;
66 template<
int One_,
int Two_>
69 static_assert(One_ != Two_,
"Symmetries must cover distinct indices.");
70 constexpr
static int One = One_;
71 constexpr
static int Two = Two_;
100 template<
typename... Gen>
122 template<
typename... Gen>
137 template<
typename... Gen>
138 class SGroup :
public internal::tensor_symmetry_pre_analysis<internal::tensor_symmetry_num_indices<Gen...>::value, Gen...>::root_type
141 constexpr
static std::size_t
NumIndices = internal::tensor_symmetry_num_indices<Gen...>::value;
142 typedef typename internal::tensor_symmetry_pre_analysis<
NumIndices, Gen...>::root_type
Base;
156 template<
typename... Sym>
struct tensor_symmetry_num_indices
158 constexpr
static std::size_t value = 1;
161 template<
int One_,
int Two_,
typename... Sym>
struct tensor_symmetry_num_indices<Symmetry<One_, Two_>, Sym...>
164 constexpr
static std::size_t One =
static_cast<std::size_t
>(One_);
165 constexpr
static std::size_t Two =
static_cast<std::size_t
>(Two_);
166 constexpr
static std::size_t Three = tensor_symmetry_num_indices<Sym...>::value;
169 constexpr
static std::size_t maxOneTwoPlusOne = ((One > Two) ? One : Two) + 1;
171 constexpr
static std::size_t value = (maxOneTwoPlusOne > Three) ? maxOneTwoPlusOne : Three;
174 template<
int One_,
int Two_,
typename... Sym>
struct tensor_symmetry_num_indices<AntiSymmetry<One_, Two_>, Sym...>
175 :
public tensor_symmetry_num_indices<Symmetry<One_, Two_>, Sym...> {};
176 template<
int One_,
int Two_,
typename... Sym>
struct tensor_symmetry_num_indices<Hermiticity<One_, Two_>, Sym...>
177 :
public tensor_symmetry_num_indices<Symmetry<One_, Two_>, Sym...> {};
178 template<
int One_,
int Two_,
typename... Sym>
struct tensor_symmetry_num_indices<AntiHermiticity<One_, Two_>, Sym...>
179 :
public tensor_symmetry_num_indices<Symmetry<One_, Two_>, Sym...> {};
228 template<std::
size_t NumIndices>
229 struct tensor_symmetry_pre_analysis<NumIndices>
231 typedef StaticSGroup<> root_type;
234 template<std::size_t NumIndices,
typename Gen_,
typename... Gens_>
235 struct tensor_symmetry_pre_analysis<NumIndices, Gen_, Gens_...>
237 constexpr
static std::size_t max_static_generators = 4;
238 constexpr
static std::size_t max_static_elements = 16;
239 typedef tensor_static_symgroup_if<(
sizeof...(Gens_) + 1 <= max_static_generators), NumIndices, Gen_, Gens_...> helper;
240 constexpr
static std::size_t possible_size = helper::size;
242 typedef std::conditional_t<
243 possible_size == 0 || possible_size >= max_static_elements,
244 DynamicSGroupFromTemplateArgs<Gen_, Gens_...>,
245 typename helper::type
249 template<
bool instantiate, std::size_t NumIndices,
typename... Gens>
250 struct tensor_static_symgroup_if
252 constexpr
static std::size_t
size = 0;
256 template<std::size_t NumIndices,
typename... Gens>
257 struct tensor_static_symgroup_if<true, NumIndices, Gens...> : tensor_static_symgroup<NumIndices, Gens...> {};
259 template<
typename Tensor_>
260 struct tensor_symmetry_assign_value
262 typedef typename Tensor_::Index
Index;
263 typedef typename Tensor_::Scalar Scalar;
264 constexpr
static std::size_t NumIndices = Tensor_::NumIndices;
266 static inline int run(
const std::array<Index, NumIndices>& transformed_indices,
int transformation_flags,
int dummy, Tensor_& tensor,
const Scalar& value_)
268 Scalar value(value_);
270 value = numext::conj(value);
273 tensor.coeffRef(transformed_indices) = value;
278 template<
typename Tensor_>
279 struct tensor_symmetry_calculate_flags
281 typedef typename Tensor_::Index
Index;
282 constexpr
static std::size_t NumIndices = Tensor_::NumIndices;
284 static inline int run(
const std::array<Index, NumIndices>& transformed_indices,
int transform_flags,
int current_flags,
const std::array<Index, NumIndices>& orig_indices)
286 if (transformed_indices == orig_indices) {
294 return current_flags;
298 template<
typename Tensor_,
typename Symmetry_,
int Flags = 0>
299 class tensor_symmetry_value_setter
302 typedef typename Tensor_::Index
Index;
303 typedef typename Tensor_::Scalar Scalar;
304 constexpr
static std::size_t NumIndices = Tensor_::NumIndices;
306 inline tensor_symmetry_value_setter(Tensor_& tensor, Symmetry_
const& symmetry, std::array<Index, NumIndices>
const& indices)
307 : m_tensor(tensor), m_symmetry(symmetry), m_indices(indices) { }
309 inline tensor_symmetry_value_setter<Tensor_, Symmetry_, Flags>& operator=(Scalar
const& value)
316 Symmetry_ m_symmetry;
317 std::array<Index, NumIndices> m_indices;
319 inline void doAssign(Scalar
const& value)
321 #ifdef EIGEN_TENSOR_SYMMETRY_CHECK_VALUES
322 int value_flags = m_symmetry.template apply<internal::tensor_symmetry_calculate_flags<Tensor_>,
int>(m_indices, m_symmetry.globalFlags(), m_indices);
328 m_symmetry.template apply<internal::tensor_symmetry_assign_value<Tensor_>,
int>(m_indices, 0, m_tensor, value);
Symmetry group, initialized from template arguments.
SGroup< Gen... > & operator=(const SGroup< Gen... > &other)
SGroup< Gen... > & operator=(SGroup< Gen... > &&other)
constexpr static std::size_t NumIndices
internal::tensor_symmetry_pre_analysis< NumIndices, Gen... >::root_type Base
SGroup(SGroup< Gen... > &&other)
SGroup(const SGroup< Gen... > &other)
: TensorContractionSycl.h, provides various tensor contraction kernel for SYCL backend
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
const adouble & real(const adouble &x)
adouble imag(const adouble &)
constexpr static int Flags
constexpr static int Flags
constexpr static int Flags
constexpr static int Flags