10 #ifndef EIGEN_VISITOR_H
11 #define EIGEN_VISITOR_H
19 template <
typename Visitor,
typename Derived,
int UnrollCount,
20 bool Vectorize = (Derived::PacketAccess && functor_traits<Visitor>::PacketAccess),
bool LinearAccess =
false,
21 bool ShortCircuitEvaluation =
false>
24 template <
typename Visitor,
bool ShortCircuitEvaluation = false>
25 struct short_circuit_eval_impl {
29 template <
typename Visitor>
30 struct short_circuit_eval_impl<Visitor, true> {
33 return visitor.done();
38 template <
typename Visitor,
typename Derived,
int UnrollCount,
bool Vectorize,
bool ShortCircuitEvaluation>
39 struct visitor_impl<Visitor, Derived, UnrollCount, Vectorize, false, ShortCircuitEvaluation> {
41 using Scalar =
typename Derived::Scalar;
42 using Packet =
typename packet_traits<Scalar>::type;
43 static constexpr
bool RowMajor = Derived::IsRowMajor;
44 static constexpr
int RowsAtCompileTime = Derived::RowsAtCompileTime;
45 static constexpr
int ColsAtCompileTime = Derived::ColsAtCompileTime;
48 static constexpr
bool CanVectorize(
int K) {
49 constexpr
int InnerSizeAtCompileTime =
RowMajor ? ColsAtCompileTime : RowsAtCompileTime;
50 if(InnerSizeAtCompileTime < PacketSize)
return false;
51 return Vectorize && (InnerSizeAtCompileTime - (K % InnerSizeAtCompileTime) >= PacketSize);
55 bool Empty = (K == UnrollCount),
56 std::enable_if_t<Empty, bool> =
true>
57 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived&, Visitor&) {}
60 bool Empty = (K == UnrollCount),
61 bool Initialize = (K == 0),
62 bool DoVectorOp = CanVectorize(K),
63 std::enable_if_t<!Empty && Initialize && !DoVectorOp, bool> =
true>
66 visitor.init(
mat.coeff(0, 0), 0, 0);
71 bool Empty = (K == UnrollCount),
72 bool Initialize = (K == 0),
73 bool DoVectorOp = CanVectorize(K),
74 std::enable_if_t<!Empty && !Initialize && !DoVectorOp, bool> =
true>
77 static constexpr
int R =
RowMajor ? (K / ColsAtCompileTime) : (K % RowsAtCompileTime);
78 static constexpr
int C =
RowMajor ? (K % ColsAtCompileTime) : (K / RowsAtCompileTime);
79 visitor(
mat.coeff(R, C), R, C);
80 run<K + 1>(
mat, visitor);
84 bool Empty = (K == UnrollCount),
85 bool Initialize = (K == 0),
86 bool DoVectorOp = CanVectorize(K),
87 std::enable_if_t<!Empty && Initialize && DoVectorOp, bool> =
true>
90 Packet
P =
mat.template packet<Packet>(0, 0);
91 visitor.initpacket(
P, 0, 0);
92 run<PacketSize>(
mat, visitor);
96 bool Empty = (K == UnrollCount),
97 bool Initialize = (K == 0),
98 bool DoVectorOp = CanVectorize(K),
99 std::enable_if_t<!Empty && !Initialize && DoVectorOp, bool> =
true>
102 static constexpr
int R =
RowMajor ? (K / ColsAtCompileTime) : (K % RowsAtCompileTime);
103 static constexpr
int C =
RowMajor ? (K % ColsAtCompileTime) : (K / RowsAtCompileTime);
104 Packet
P =
mat.template packet<Packet>(R, C);
105 visitor.packet(
P, R, C);
106 run<K + PacketSize>(
mat, visitor);
111 template <
typename Visitor,
typename Derived,
int UnrollCount,
bool Vectorize,
bool ShortCircuitEvaluation>
112 struct visitor_impl<Visitor, Derived, UnrollCount, Vectorize, true, ShortCircuitEvaluation> {
114 using Scalar =
typename Derived::Scalar;
115 using Packet =
typename packet_traits<Scalar>::type;
118 static constexpr
bool CanVectorize(
int K) {
119 return Vectorize && ((UnrollCount - K) >= PacketSize);
124 bool Empty = (K == UnrollCount),
125 std::enable_if_t<Empty, bool> =
true>
126 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived&, Visitor&) {}
130 bool Empty = (K == UnrollCount),
131 bool Initialize = (K == 0),
132 bool DoVectorOp = CanVectorize(K),
133 std::enable_if_t<!Empty && Initialize && !DoVectorOp, bool> =
true>
135 visitor.init(
mat.coeff(0), 0);
136 run<1>(
mat, visitor);
141 bool Empty = (K == UnrollCount),
142 bool Initialize = (K == 0),
143 bool DoVectorOp = CanVectorize(K),
144 std::enable_if_t<!Empty && !Initialize && !DoVectorOp, bool> =
true>
146 visitor(
mat.coeff(K), K);
147 run<K + 1>(
mat, visitor);
152 bool Empty = (K == UnrollCount),
153 bool Initialize = (K == 0),
154 bool DoVectorOp = CanVectorize(K),
155 std::enable_if_t<!Empty && Initialize && DoVectorOp, bool> =
true>
157 Packet
P =
mat.template packet<Packet>(0);
158 visitor.initpacket(
P, 0);
159 run<PacketSize>(
mat, visitor);
164 bool Empty = (K == UnrollCount),
165 bool Initialize = (K == 0),
166 bool DoVectorOp = CanVectorize(K),
167 std::enable_if_t<!Empty && !Initialize && DoVectorOp, bool> =
true>
169 Packet
P =
mat.template packet<Packet>(K);
170 visitor.packet(
P, K);
171 run<K + PacketSize>(
mat, visitor);
176 template <
typename Visitor,
typename Derived,
bool ShortCircuitEvaluation>
177 struct visitor_impl<Visitor, Derived,
Dynamic, false, false, ShortCircuitEvaluation> {
178 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
179 static constexpr
bool RowMajor = Derived::IsRowMajor;
184 if (innerSize == 0 || outerSize == 0)
return;
186 visitor.init(
mat.coeff(0, 0), 0, 0);
187 if (short_circuit::run(visitor))
return;
188 for (
Index i = 1;
i < innerSize; ++
i) {
191 visitor(
mat.coeff(r,
c), r,
c);
195 for (
Index j = 1;
j < outerSize;
j++) {
196 for (
Index i = 0;
i < innerSize; ++
i) {
199 visitor(
mat.coeff(r,
c), r,
c);
207 template <
typename Visitor,
typename Derived,
bool ShortCircuitEvaluation>
208 struct visitor_impl<Visitor, Derived,
Dynamic, true, false, ShortCircuitEvaluation> {
209 using Scalar =
typename Derived::Scalar;
210 using Packet =
typename packet_traits<Scalar>::type;
212 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
213 static constexpr
bool RowMajor = Derived::IsRowMajor;
218 if (innerSize == 0 || outerSize == 0)
return;
221 if (innerSize < PacketSize) {
222 visitor.init(
mat.coeff(0, 0), 0, 0);
225 Packet
p =
mat.template packet<Packet>(0, 0);
226 visitor.initpacket(
p, 0, 0);
230 for (;
i + PacketSize - 1 < innerSize;
i += PacketSize) {
233 Packet
p =
mat.template packet<Packet>(r,
c);
234 visitor.packet(
p, r,
c);
237 for (;
i < innerSize; ++
i) {
240 visitor(
mat.coeff(r,
c), r,
c);
244 for (
Index j = 1;
j < outerSize;
j++) {
246 for (;
i + PacketSize - 1 < innerSize;
i += PacketSize) {
249 Packet
p =
mat.template packet<Packet>(r,
c);
250 visitor.packet(
p, r,
c);
253 for (;
i < innerSize; ++
i) {
256 visitor(
mat.coeff(r,
c), r,
c);
264 template <
typename Visitor,
typename Derived,
bool ShortCircuitEvaluation>
265 struct visitor_impl<Visitor, Derived,
Dynamic, false, true, ShortCircuitEvaluation> {
266 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
270 if (
size == 0)
return;
271 visitor.init(
mat.coeff(0), 0);
274 visitor(
mat.coeff(k), k);
281 template <
typename Visitor,
typename Derived,
bool ShortCircuitEvaluation>
282 struct visitor_impl<Visitor, Derived,
Dynamic, true, true, ShortCircuitEvaluation> {
283 using Scalar =
typename Derived::Scalar;
284 using Packet =
typename packet_traits<Scalar>::type;
286 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
290 if (
size == 0)
return;
292 if (
size < PacketSize) {
293 visitor.init(
mat.coeff(0), 0);
296 Packet
p =
mat.template packet<Packet>(k);
297 visitor.initpacket(
p, k);
301 for (; k + PacketSize - 1 <
size; k += PacketSize) {
302 Packet
p =
mat.template packet<Packet>(k);
303 visitor.packet(
p, k);
306 for (; k <
size; k++) {
307 visitor(
mat.coeff(k), k);
314 template<
typename XprType>
315 class visitor_evaluator
318 typedef evaluator<XprType> Evaluator;
319 typedef typename XprType::Scalar Scalar;
320 using Packet =
typename packet_traits<Scalar>::type;
321 typedef std::remove_const_t<typename XprType::CoeffReturnType> CoeffReturnType;
323 static constexpr
bool PacketAccess =
static_cast<bool>(Evaluator::Flags &
PacketAccessBit);
324 static constexpr
bool LinearAccess =
static_cast<bool>(Evaluator::Flags &
LinearAccessBit);
325 static constexpr
bool IsRowMajor =
static_cast<bool>(XprType::IsRowMajor);
326 static constexpr
int RowsAtCompileTime = XprType::RowsAtCompileTime;
327 static constexpr
int ColsAtCompileTime = XprType::ColsAtCompileTime;
328 static constexpr
int XprAlignment = Evaluator::Alignment;
329 static constexpr
int CoeffReadCost = Evaluator::CoeffReadCost;
332 explicit visitor_evaluator(
const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) { }
339 template <
typename Packet,
int Alignment = Unaligned>
341 return m_evaluator.template packet<Alignment, Packet>(
row,
col);
344 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(
Index index)
const {
return m_evaluator.coeff(index); }
345 template <
typename Packet,
int Alignment = XprAlignment>
347 return m_evaluator.template packet<Alignment, Packet>(index);
351 Evaluator m_evaluator;
352 const XprType &m_xpr;
355 template <
typename Derived,
typename Visitor,
bool ShortCircuitEvaulation>
357 using Evaluator = visitor_evaluator<Derived>;
364 static constexpr
int InnerSizeAtCompileTime = IsRowMajor ? ColsAtCompileTime : RowsAtCompileTime;
365 static constexpr
int OuterSizeAtCompileTime = IsRowMajor ? RowsAtCompileTime : ColsAtCompileTime;
367 static constexpr
bool LinearAccess = Evaluator::LinearAccess &&
static_cast<bool>(functor_traits<Visitor>::LinearAccess);
368 static constexpr
bool Vectorize = Evaluator::PacketAccess &&
static_cast<bool>(functor_traits<Visitor>::PacketAccess);
371 static constexpr
int VectorOps = Vectorize ? (LinearAccess ? (SizeAtCompileTime / PacketSize) : (OuterSizeAtCompileTime * (InnerSizeAtCompileTime / PacketSize))) : 0;
372 static constexpr
int ScalarOps = SizeAtCompileTime - (VectorOps * PacketSize);
374 static constexpr
int TotalOps = VectorOps + ScalarOps;
376 static constexpr
int UnrollCost = int(Evaluator::CoeffReadCost) + int(functor_traits<Visitor>::Cost);
378 static constexpr
int UnrollCount = Unroll ? int(SizeAtCompileTime) :
Dynamic;
381 using impl = visitor_impl<Visitor, Evaluator, UnrollCount, Vectorize, LinearAccess, ShortCircuitEvaulation>;
383 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const DenseBase<Derived>&
mat, Visitor& visitor) {
384 Evaluator evaluator(
mat.derived());
385 impl::run(evaluator, visitor);
410 template<
typename Derived>
411 template<
typename Visitor>
415 using impl = internal::visit_impl<Derived, Visitor,
false>;
416 impl::run(derived(), visitor);
424 template <
typename Derived>
429 coeff_visitor() :
row(-1),
col(-1),
res(0) {}
430 typedef typename Derived::Scalar Scalar;
443 template <
typename Scalar,
int NaNPropagation,
bool is_min = true>
444 struct minmax_compare {
445 typedef typename packet_traits<Scalar>::type Packet;
450 template <
typename Scalar,
int NaNPropagation>
451 struct minmax_compare<Scalar, NaNPropagation, false> {
452 typedef typename packet_traits<Scalar>::type Packet;
459 template <
typename Derived,
bool is_min,
int NaNPropagation,
461 struct minmax_coeff_visitor : coeff_visitor<Derived> {
462 using Scalar =
typename Derived::Scalar;
463 using Packet =
typename packet_traits<Scalar>::type;
464 using Comparator = minmax_compare<Scalar, NaNPropagation, is_min>;
468 if (Comparator::compare(value, this->
res)) {
476 if (Comparator::compare(value, this->
res)) {
481 this->row = Derived::IsRowMajor ?
i :
i + max_idx;
482 this->col = Derived::IsRowMajor ?
j + max_idx :
j;
487 const Packet range =
preverse(plset<Packet>(Scalar(1)));
488 Packet mask =
pcmp_eq(pset1<Packet>(value),
p);
491 this->row = Derived::IsRowMajor ?
i :
i + max_idx;
492 this->col = Derived::IsRowMajor ?
j + max_idx :
j;
498 template <
typename Derived,
bool is_min>
499 struct minmax_coeff_visitor<Derived, is_min,
PropagateNumbers, false> : coeff_visitor<Derived> {
500 typedef typename Derived::Scalar Scalar;
501 using Packet =
typename packet_traits<Scalar>::type;
502 using Comparator = minmax_compare<Scalar, PropagateNumbers, is_min>;
515 const Packet range =
preverse(plset<Packet>(Scalar(1)));
517 Packet mask =
pcmp_eq(pset1<Packet>(value),
p);
520 this->row = Derived::IsRowMajor ?
i :
i + max_idx;
521 this->col = Derived::IsRowMajor ?
j + max_idx :
j;
535 Packet mask =
pcmp_eq(pset1<Packet>(value),
p);
538 this->row = Derived::IsRowMajor ?
i :
i + max_idx;
539 this->col = Derived::IsRowMajor ?
j + max_idx :
j;
545 template <
typename Derived,
bool is_min,
int NaNPropagation>
546 struct minmax_coeff_visitor<Derived, is_min, NaNPropagation, false> : coeff_visitor<Derived> {
547 typedef typename Derived::Scalar Scalar;
548 using Packet =
typename packet_traits<Scalar>::type;
549 using Comparator = minmax_compare<Scalar, PropagateNaN, is_min>;
553 if ((value_is_nan && !(
numext::isnan)(this->
res)) || Comparator::compare(value, this->
res)) {
563 if ((value_is_nan && !(
numext::isnan)(this->
res)) || Comparator::compare(value, this->
res)) {
564 const Packet range =
preverse(plset<Packet>(Scalar(1)));
569 this->row = Derived::IsRowMajor ?
i :
i + max_idx;
570 this->col = Derived::IsRowMajor ?
j + max_idx :
j;
577 const Packet range =
preverse(plset<Packet>(Scalar(1)));
582 this->row = Derived::IsRowMajor ?
i :
i + max_idx;
583 this->col = Derived::IsRowMajor ?
j + max_idx :
j;
587 template<
typename Derived,
bool is_min,
int NaNPropagation>
588 struct functor_traits<minmax_coeff_visitor<Derived, is_min, NaNPropagation> > {
589 using Scalar =
typename Derived::Scalar;
592 LinearAccess =
false,
593 PacketAccess = packet_traits<Scalar>::HasCmp
597 template <
typename Scalar>
599 using result_type =
bool;
600 using Packet =
typename packet_traits<Scalar>::type;
613 template <
typename Scalar>
614 struct functor_traits<all_visitor<Scalar>> {
618 template <
typename Scalar>
620 using result_type =
bool;
621 using Packet =
typename packet_traits<Scalar>::type;
636 template <
typename Scalar>
637 struct functor_traits<any_visitor<Scalar>> {
641 template <
typename Scalar>
642 struct count_visitor {
643 using result_type =
Index;
644 using Packet =
typename packet_traits<Scalar>::type;
648 const Packet cst_one = pset1<Packet>(Scalar(1));
650 Scalar num_true =
predux(true_vals);
651 return static_cast<Index>(num_true);
656 if (value != Scalar(0))
res++;
659 if (value != Scalar(0))
res++;
666 template <
typename Scalar>
667 struct functor_traits<count_visitor<Scalar>> {
672 PacketAccess = packet_traits<Scalar>::HasCmp && packet_traits<Scalar>::HasAdd && !is_same<Scalar, bool>::value
689 template<
typename Derived>
690 template<
int NaNPropagation,
typename IndexType>
692 typename internal::traits<Derived>::Scalar
697 internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
698 this->visit(minVisitor);
699 *rowId = minVisitor.row;
700 if (colId) *colId = minVisitor.col;
701 return minVisitor.res;
714 template<
typename Derived>
715 template<
int NaNPropagation,
typename IndexType>
717 typename internal::traits<Derived>::Scalar
723 internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
724 this->visit(minVisitor);
725 *index = IndexType((RowsAtCompileTime==1) ? minVisitor.col : minVisitor.row);
726 return minVisitor.res;
740 template<
typename Derived>
741 template<
int NaNPropagation,
typename IndexType>
743 typename internal::traits<Derived>::Scalar
748 internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
749 this->visit(maxVisitor);
750 *rowPtr = maxVisitor.row;
751 if (colPtr) *colPtr = maxVisitor.col;
752 return maxVisitor.res;
765 template<
typename Derived>
766 template<
int NaNPropagation,
typename IndexType>
768 typename internal::traits<Derived>::Scalar
774 internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
775 this->visit(maxVisitor);
776 *index = (RowsAtCompileTime==1) ? maxVisitor.col : maxVisitor.row;
777 return maxVisitor.res;
787 template <
typename Derived>
789 using Visitor = internal::all_visitor<Scalar>;
790 using impl = internal::visit_impl<Derived, Visitor,
true>;
792 impl::run(derived(), visitor);
800 template <
typename Derived>
802 using Visitor = internal::any_visitor<Scalar>;
803 using impl = internal::visit_impl<Derived, Visitor,
true>;
805 impl::run(derived(), visitor);
813 template<
typename Derived>
817 using Visitor = internal::count_visitor<Scalar>;
818 using impl = internal::visit_impl<Derived, Visitor,
false>;
820 impl::run(derived(), visitor);
825 template <
typename Derived>
827 return derived().cwiseTypedNotEqual(derived()).any();
834 template <
typename Derived>
836 return derived().array().isFinite().all();
RowXpr row(Index i)
This is the const version of row(). */.
ColXpr col(Index i)
This is the const version of col().
Projective3d P(Matrix4d::Random())
IndexedView_or_Block operator()(const RowIndices &rowIndices, const ColIndices &colIndices)
#define EIGEN_PREDICT_FALSE(x)
#define EIGEN_DEVICE_FUNC
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
#define EIGEN_UNROLLING_LIMIT
#define EIGEN_STATIC_ASSERT_VECTOR_ONLY(TYPE)
internal::traits< Derived >::Scalar minCoeff() const
void visit(Visitor &func) const
internal::traits< Derived >::Scalar Scalar
internal::traits< Derived >::Scalar maxCoeff() const
const unsigned int PacketAccessBit
const unsigned int LinearAccessBit
Packet8f pzero(const Packet8f &)
unpacket_traits< Packet >::type predux(const Packet &a)
Packet8h ptrue(const Packet8h &a)
Packet8h pandnot(const Packet8h &a, const Packet8h &b)
Packet2cf pcmp_eq(const Packet2cf &a, const Packet2cf &b)
Packet8h pand(const Packet8h &a, const Packet8h &b)
Packet pnot(const Packet &a)
Packet pset1(const typename unpacket_traits< Packet >::type &a)
unpacket_traits< Packet >::type predux_max(const Packet &a)
Packet2cf preverse(const Packet2cf &a)
bool predux_any(const Packet4f &x)
EIGEN_ALWAYS_INLINE bool() isnan(const Eigen::bfloat16 &h)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Eigen::Index Index
The interface type of indices.