10 #ifndef EIGEN_SPARSE_PERMUTATION_H
11 #define EIGEN_SPARSE_PERMUTATION_H
21 template<typename ExpressionType, typename PlainObjectType, bool NeedEval = !is_same<ExpressionType, PlainObjectType>::value>
24 XprHelper(
const ExpressionType& xpr) : m_xpr(xpr) {}
25 inline const PlainObjectType& xpr()
const {
return m_xpr; }
27 const PlainObjectType m_xpr;
29 template<
typename ExpressionType,
typename PlainObjectType>
30 struct XprHelper<ExpressionType, PlainObjectType, false>
32 XprHelper(
const ExpressionType& xpr) : m_xpr(xpr) {}
33 inline const PlainObjectType& xpr()
const {
return m_xpr; }
35 const PlainObjectType& m_xpr;
38 template<
typename PermDerived,
bool NeedInverseEval>
41 using IndicesType =
typename PermDerived::IndicesType;
42 using PermutationIndex =
typename IndicesType::Scalar;
43 using type = PermutationMatrix<IndicesType::SizeAtCompileTime, IndicesType::MaxSizeAtCompileTime, PermutationIndex>;
44 PermHelper(
const PermDerived& perm) : m_perm(perm.
inverse()) {}
45 inline const type& perm()
const {
return m_perm; }
49 template<
typename PermDerived>
50 struct PermHelper<PermDerived, false>
52 using type = PermDerived;
53 PermHelper(
const PermDerived& perm) : m_perm(perm) {}
54 inline const type& perm()
const {
return m_perm; }
59 template<
typename ExpressionType,
int S
ide,
bool Transposed>
60 struct permutation_matrix_product<ExpressionType, Side, Transposed, SparseShape>
62 using MatrixType =
typename nested_eval<ExpressionType, 1>::type;
63 using MatrixTypeCleaned = remove_all_t<MatrixType>;
65 using Scalar =
typename MatrixTypeCleaned::Scalar;
66 using StorageIndex =
typename MatrixTypeCleaned::StorageIndex;
69 using ReturnType = SparseMatrix<Scalar, MatrixTypeCleaned::IsRowMajor ? RowMajor : ColMajor, StorageIndex>;
70 using TmpHelper = XprHelper<ExpressionType, ReturnType>;
72 static constexpr
bool NeedOuterPermutation = ExpressionType::IsRowMajor ? Side ==
OnTheLeft : Side ==
OnTheRight;
73 static constexpr
bool NeedInversePermutation = Transposed ? Side ==
OnTheLeft : Side ==
OnTheRight;
75 template <
typename Dest,
typename PermutationType>
76 static inline void permute_outer(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr) {
81 const TmpHelper tmpHelper(xpr);
82 const ReturnType& tmp = tmpHelper.xpr();
84 ReturnType result(tmp.rows(), tmp.cols());
86 for (
Index j = 0;
j < tmp.outerSize();
j++) {
87 Index jp = perm.indices().coeff(
j);
88 Index jsrc = NeedInversePermutation ? jp :
j;
89 Index jdst = NeedInversePermutation ?
j : jp;
90 Index begin = tmp.outerIndexPtr()[jsrc];
91 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[jsrc + 1] : begin + tmp.innerNonZeroPtr()[jsrc];
92 result.outerIndexPtr()[jdst + 1] +=
end - begin;
95 std::partial_sum(result.outerIndexPtr(), result.outerIndexPtr() + result.outerSize() + 1,
96 result.outerIndexPtr());
97 result.resizeNonZeros(result.nonZeros());
99 for (
Index j = 0;
j < tmp.outerSize();
j++) {
100 Index jp = perm.indices().coeff(
j);
101 Index jsrc = NeedInversePermutation ? jp :
j;
102 Index jdst = NeedInversePermutation ?
j : jp;
103 Index begin = tmp.outerIndexPtr()[jsrc];
104 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[jsrc + 1] : begin + tmp.innerNonZeroPtr()[jsrc];
105 Index target = result.outerIndexPtr()[jdst];
106 smart_copy(tmp.innerIndexPtr() + begin, tmp.innerIndexPtr() +
end, result.innerIndexPtr() + target);
107 smart_copy(tmp.valuePtr() + begin, tmp.valuePtr() +
end, result.valuePtr() + target);
109 dst = std::move(result);
112 template <
typename Dest,
typename PermutationType>
113 static inline void permute_inner(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr) {
114 using InnerPermHelper = PermHelper<PermutationType, NeedInversePermutation>;
115 using InnerPermType =
typename InnerPermHelper::type;
120 const TmpHelper tmpHelper(xpr);
121 const ReturnType& tmp = tmpHelper.xpr();
125 const InnerPermHelper permHelper(perm);
126 const InnerPermType& innerPerm = permHelper.perm();
128 ReturnType result(tmp.rows(), tmp.cols());
130 for (
Index j = 0;
j < tmp.outerSize();
j++) {
131 Index begin = tmp.outerIndexPtr()[
j];
132 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[
j + 1] : begin + tmp.innerNonZeroPtr()[
j];
133 result.outerIndexPtr()[
j + 1] +=
end - begin;
136 std::partial_sum(result.outerIndexPtr(), result.outerIndexPtr() + result.outerSize() + 1, result.outerIndexPtr());
137 result.resizeNonZeros(result.nonZeros());
139 for (
Index j = 0;
j < tmp.outerSize();
j++) {
140 Index begin = tmp.outerIndexPtr()[
j];
141 Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[
j + 1] : begin + tmp.innerNonZeroPtr()[
j];
142 Index target = result.outerIndexPtr()[
j];
143 std::transform(tmp.innerIndexPtr() + begin, tmp.innerIndexPtr() +
end, result.innerIndexPtr() + target,
144 [&innerPerm](StorageIndex
i) { return innerPerm.indices().coeff(i); });
145 smart_copy(tmp.valuePtr() + begin, tmp.valuePtr() +
end, result.valuePtr() + target);
148 result.sortInnerIndices();
149 dst = std::move(result);
152 template <
typename Dest,
typename PermutationType,
bool DoOuter = NeedOuterPermutation, std::enable_if_t<DoOuter,
int> = 0>
153 static inline void run(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr) { permute_outer(dst, perm, xpr); }
155 template <
typename Dest,
typename PermutationType,
bool DoOuter = NeedOuterPermutation, std::enable_if_t<!DoOuter,
int> = 0>
156 static inline void run(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr) { permute_inner(dst, perm, xpr); }
163 template <
int ProductTag>
struct product_promote_storage_type<Sparse, PermutationStorage, ProductTag> {
typedef Sparse ret; };
164 template <
int ProductTag>
struct product_promote_storage_type<PermutationStorage, Sparse, ProductTag> {
typedef Sparse ret; };
170 template<
typename Lhs,
typename Rhs,
int ProductTag>
171 struct product_evaluator<Product<Lhs, Rhs,
AliasFreeProduct>, ProductTag, PermutationShape, SparseShape>
172 :
public evaluator<typename permutation_matrix_product<Rhs,OnTheLeft,false,SparseShape>::ReturnType>
174 typedef Product<Lhs, Rhs, AliasFreeProduct> XprType;
175 typedef typename permutation_matrix_product<Rhs,OnTheLeft,false,SparseShape>::ReturnType PlainObject;
176 typedef evaluator<PlainObject> Base;
182 explicit product_evaluator(
const XprType& xpr)
185 internal::construct_at<Base>(
this, m_result);
186 generic_product_impl<Lhs, Rhs, PermutationShape, SparseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
190 PlainObject m_result;
193 template<
typename Lhs,
typename Rhs,
int ProductTag>
194 struct product_evaluator<Product<Lhs, Rhs,
AliasFreeProduct>, ProductTag, SparseShape, PermutationShape >
195 :
public evaluator<typename permutation_matrix_product<Lhs,OnTheRight,false,SparseShape>::ReturnType>
197 typedef Product<Lhs, Rhs, AliasFreeProduct> XprType;
198 typedef typename permutation_matrix_product<Lhs,OnTheRight,false,SparseShape>::ReturnType PlainObject;
199 typedef evaluator<PlainObject> Base;
205 explicit product_evaluator(
const XprType& xpr)
208 ::new (
static_cast<Base*
>(
this)) Base(m_result);
209 generic_product_impl<Lhs, Rhs, SparseShape, PermutationShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
213 PlainObject m_result;
220 template <typename SparseDerived, typename PermDerived>
228 template <
typename SparseDerived,
typename PermDerived>
236 template <
typename SparseDerived,
typename PermutationType>
244 template <
typename SparseDerived,
typename PermutationType>
Matrix< float, 1, Dynamic > MatrixType
Base class for permutations.
Expression of the product of two arbitrary matrices or vectors.
Base class of any sparse matrices or sparse expressions.
static const lastp1_t end
const unsigned int EvalBeforeNestingBit
void smart_copy(const T *start, const T *end, T *target)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
const Product< MatrixDerived, PermutationDerived, AliasFreeProduct > operator*(const MatrixBase< MatrixDerived > &matrix, const PermutationBase< PermutationDerived > &permutation)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_inverse_op< typename Derived::Scalar >, const Derived > inverse(const Eigen::ArrayBase< Derived > &x)