10 #ifndef EIGEN_CXX11_TENSORSYMMETRY_STATICSYMMETRY_H
11 #define EIGEN_CXX11_TENSORSYMMETRY_STATICSYMMETRY_H
19 template<
typename list>
struct tensor_static_symgroup_permutate;
22 struct tensor_static_symgroup_permutate<numeric_list<int, nn...>>
24 constexpr
static std::size_t N =
sizeof...(nn);
27 constexpr
static inline std::array<T, N> run(
const std::array<T, N>& indices)
29 return {{indices[nn]...}};
33 template<
typename indices_,
int flags_>
34 struct tensor_static_symgroup_element
36 typedef indices_ indices;
37 constexpr
static int flags = flags_;
40 template<
typename Gen,
int N>
41 struct tensor_static_symgroup_element_ctor
43 typedef tensor_static_symgroup_element<
44 typename gen_numeric_list_swapped_pair<int, N, Gen::One, Gen::Two>::type,
50 struct tensor_static_symgroup_identity_ctor
52 typedef tensor_static_symgroup_element<
53 typename gen_numeric_list<int, N>::type,
58 template<
typename iib>
59 struct tensor_static_symgroup_multiply_helper
62 constexpr
static inline numeric_list<int, get<iia, iib>::value...> helper(numeric_list<int, iia...>) {
63 return numeric_list<int, get<iia, iib>::value...>();
67 template<
typename A,
typename B>
68 struct tensor_static_symgroup_multiply
71 typedef typename A::indices iia;
72 typedef typename B::indices iib;
73 constexpr
static int ffa = A::flags;
74 constexpr
static int ffb = B::flags;
77 static_assert(iia::count == iib::count,
"Cannot multiply symmetry elements with different number of indices.");
79 typedef tensor_static_symgroup_element<
80 decltype(tensor_static_symgroup_multiply_helper<iib>::helper(iia())),
85 template<
typename A,
typename B>
86 struct tensor_static_symgroup_equality
88 typedef typename A::indices iia;
89 typedef typename B::indices iib;
90 constexpr
static int ffa = A::flags;
91 constexpr
static int ffb = B::flags;
92 static_assert(iia::count == iib::count,
"Cannot compare symmetry elements with different number of indices.");
94 constexpr
static bool value = is_same<iia, iib>::value;
100 constexpr
static int flags_cmp_ = ffa ^ ffb;
105 constexpr
static bool is_zero = value && flags_cmp_ ==
NegationFlag;
106 constexpr
static bool is_real = value && flags_cmp_ ==
ConjugationFlag;
110 constexpr
static int global_flags =
116 template<std::size_t NumIndices,
typename... Gen>
117 struct tensor_static_symgroup
119 typedef StaticSGroup<Gen...> type;
123 template<
typename Index, std::size_t N,
int... ii,
int... jj>
126 return {{ idx[ii]..., idx[jj]... }};
129 template<
typename Index,
int... ii>
132 std::vector<Index> result{{ idx[ii]... }};
133 std::size_t target_size = idx.size();
134 for (std::size_t
i = result.size();
i < target_size;
i++)
135 result.push_back(idx[
i]);
139 template<
typename T>
struct tensor_static_symgroup_do_apply;
141 template<
typename first,
typename... next>
142 struct tensor_static_symgroup_do_apply<
internal::type_list<first, next...>>
144 template<
typename Op,
typename RV, std::size_t SGNumIndices,
typename Index, std::size_t NumIndices,
typename... Args>
145 static inline RV run(
const std::array<Index, NumIndices>& idx, RV initial, Args&&... args)
147 static_assert(NumIndices >= SGNumIndices,
"Can only apply symmetry group to objects that have at least the required amount of indices.");
148 typedef typename internal::gen_numeric_list<int, NumIndices - SGNumIndices, SGNumIndices>::type remaining_indices;
150 return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...);
153 template<
typename Op,
typename RV, std::size_t SGNumIndices,
typename Index,
typename... Args>
154 static inline RV run(
const std::vector<Index>& idx, RV initial, Args&&... args)
156 eigen_assert(idx.size() >= SGNumIndices &&
"Can only apply symmetry group to objects that have at least the required amount of indices.");
158 return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...);
162 template<EIGEN_TPL_PP_SPEC_HACK_DEF(
typename, empty)>
163 struct tensor_static_symgroup_do_apply<
internal::type_list<EIGEN_TPL_PP_SPEC_HACK_USE(empty)>>
165 template<
typename Op,
typename RV, std::size_t SGNumIndices,
typename Index, std::size_t NumIndices,
typename... Args>
166 static inline RV run(
const std::array<Index, NumIndices>&, RV initial, Args&&...)
172 template<
typename Op,
typename RV, std::size_t SGNumIndices,
typename Index,
typename... Args>
173 static inline RV run(
const std::vector<Index>&, RV initial, Args&&...)
182 template<
typename... Gen>
185 constexpr
static std::size_t
NumIndices = internal::tensor_symmetry_num_indices<Gen...>::value;
186 typedef internal::group_theory::enumerate_group_elements<
187 internal::tensor_static_symgroup_multiply,
188 internal::tensor_static_symgroup_equality,
189 typename internal::tensor_static_symgroup_identity_ctor<NumIndices>::type,
190 internal::type_list<typename internal::tensor_static_symgroup_element_ctor<Gen, NumIndices>::type...>
192 typedef typename group_elements::type
ge;
198 template<
typename Op,
typename RV,
typename Index, std::size_t N,
typename... Args>
199 static inline RV
apply(
const std::array<Index, N>& idx, RV initial, Args&&... args)
201 return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...);
204 template<
typename Op,
typename RV,
typename Index,
typename... Args>
205 static inline RV
apply(
const std::vector<Index>& idx, RV initial, Args&&... args)
208 return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...);
213 constexpr
static inline std::size_t
size() {
216 constexpr
static inline int globalFlags() {
return group_elements::global_flags; }
218 template<
typename Tensor_,
typename... IndexTypes>
219 inline internal::tensor_symmetry_value_setter<Tensor_,
StaticSGroup<Gen...>>
operator()(Tensor_& tensor,
typename Tensor_::Index firstIndex, IndexTypes... otherIndices)
const
221 static_assert(
sizeof...(otherIndices) + 1 == Tensor_::NumIndices,
"Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
222 return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}});
225 template<
typename Tensor_>
226 inline internal::tensor_symmetry_value_setter<Tensor_,
StaticSGroup<Gen...>>
operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>
const& indices)
const
228 return internal::tensor_symmetry_value_setter<Tensor_,
StaticSGroup<Gen...>>(tensor, *
this, indices);
constexpr static std::size_t static_size
static RV apply(const std::array< Index, N > &idx, RV initial, Args &&... args)
internal::tensor_symmetry_value_setter< Tensor_, StaticSGroup< Gen... > > operator()(Tensor_ &tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const
constexpr static std::size_t size()
constexpr static std::size_t NumIndices
internal::group_theory::enumerate_group_elements< internal::tensor_static_symgroup_multiply, internal::tensor_static_symgroup_equality, typename internal::tensor_static_symgroup_identity_ctor< NumIndices >::type, internal::type_list< typename internal::tensor_static_symgroup_element_ctor< Gen, NumIndices >::type... > > group_elements
constexpr StaticSGroup(const StaticSGroup< Gen... > &)
internal::tensor_symmetry_value_setter< Tensor_, StaticSGroup< Gen... > > operator()(Tensor_ &tensor, std::array< typename Tensor_::Index, Tensor_::NumIndices > const &indices) const
static RV apply(const std::vector< Index > &idx, RV initial, Args &&... args)
constexpr StaticSGroup(StaticSGroup< Gen... > &&)
constexpr static int globalFlags()
constexpr static std::array< Index, N > tensor_static_symgroup_index_permute(std::array< Index, N > idx, internal::numeric_list< int, ii... >, internal::numeric_list< int, jj... >)
EIGEN_CONSTEXPR Index first(const T &x) EIGEN_NOEXCEPT
: TensorContractionSycl.h, provides various tensor contraction kernel for SYCL backend
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index