10 #ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
11 #define EIGEN_GENERAL_MATRIX_VECTOR_H
13 #include "../InternalHeaderCheck.h"
// Compile-time selector between three candidate packet types.
// This primary template is the fallback: for any N that is not matched by a
// partial specialization (GEMVPacketFull -> T1, GEMVPacketHalf -> T2), the
// third candidate T3 is chosen.
template <int N, typename T1, typename T2, typename T3>
struct gemv_packet_cond {
  using type = T3;
};
28 template <
typename T1,
typename T2,
typename T3>
29 struct gemv_packet_cond<
GEMVPacketFull, T1, T2, T3> {
typedef T1 type; };
31 template <
typename T1,
typename T2,
typename T3>
32 struct gemv_packet_cond<
GEMVPacketHalf, T1, T2, T3> {
typedef T2 type; };
34 template<
typename LhsScalar,
typename RhsScalar,
int PacketSize_=GEMVPacketFull>
37 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
39 #define PACKET_DECL_COND_POSTFIX(postfix, name, packet_size) \
40 typedef typename gemv_packet_cond<packet_size, \
41 typename packet_traits<name ## Scalar>::type, \
42 typename packet_traits<name ## Scalar>::half, \
43 typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \
44 name ## Packet ## postfix
49 #undef PACKET_DECL_COND_POSTFIX
53 Vectorizable = unpacket_traits<LhsPacket_>::vectorizable &&
54 unpacket_traits<RhsPacket_>::vectorizable &&
61 typedef std::conditional_t<Vectorizable,LhsPacket_,LhsScalar> LhsPacket;
62 typedef std::conditional_t<Vectorizable,RhsPacket_,RhsScalar> RhsPacket;
63 typedef std::conditional_t<Vectorizable,ResPacket_,ResScalar> ResPacket;
80 template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
81 struct general_matrix_vector_product<
Index,LhsScalar,LhsMapper,
ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
83 typedef gemv_traits<LhsScalar,RhsScalar> Traits;
84 typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
85 typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;
87 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
89 typedef typename Traits::LhsPacket LhsPacket;
90 typedef typename Traits::RhsPacket RhsPacket;
91 typedef typename Traits::ResPacket ResPacket;
93 typedef typename HalfTraits::LhsPacket LhsPacketHalf;
94 typedef typename HalfTraits::RhsPacket RhsPacketHalf;
95 typedef typename HalfTraits::ResPacket ResPacketHalf;
97 typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
98 typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
99 typedef typename QuarterTraits::ResPacket ResPacketQuarter;
103 const LhsMapper& lhs,
104 const RhsMapper& rhs,
109 template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
110 EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
112 const LhsMapper& alhs,
113 const RhsMapper& rhs,
124 conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
125 conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
126 conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
127 conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;
129 const Index lhsStride = lhs.stride();
132 ResPacketSize = Traits::ResPacketSize,
133 ResPacketSizeHalf = HalfTraits::ResPacketSize,
134 ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
135 LhsPacketSize = Traits::LhsPacketSize,
136 HasHalf = (int)ResPacketSizeHalf < (
int)ResPacketSize,
137 HasQuarter = (int)ResPacketSizeQuarter < (
int)ResPacketSizeHalf
145 const Index n_half =
rows-1*ResPacketSizeHalf+1;
146 const Index n_quarter =
rows-1*ResPacketSizeQuarter+1;
149 const Index block_cols =
cols<128 ?
cols : (lhsStride*
sizeof(LhsScalar)<32000?16:4);
150 ResPacket palpha = pset1<ResPacket>(alpha);
151 ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha);
152 ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha);
158 for(;
i<n8;
i+=ResPacketSize*8)
160 ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
161 c1 = pset1<ResPacket>(ResScalar(0)),
162 c2 = pset1<ResPacket>(ResScalar(0)),
163 c3 = pset1<ResPacket>(ResScalar(0)),
164 c4 = pset1<ResPacket>(ResScalar(0)),
165 c5 = pset1<ResPacket>(ResScalar(0)),
166 c6 = pset1<ResPacket>(ResScalar(0)),
167 c7 = pset1<ResPacket>(ResScalar(0));
171 RhsPacket b0 = pset1<RhsPacket>(rhs(
j,0));
172 c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*0,
j),b0,c0);
173 c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*1,
j),b0,c1);
174 c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*2,
j),b0,c2);
175 c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*3,
j),b0,c3);
176 c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*4,
j),b0,c4);
177 c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*5,
j),b0,c5);
178 c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*6,
j),b0,c6);
179 c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*7,
j),b0,c7);
192 ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
193 c1 = pset1<ResPacket>(ResScalar(0)),
194 c2 = pset1<ResPacket>(ResScalar(0)),
195 c3 = pset1<ResPacket>(ResScalar(0));
199 RhsPacket b0 = pset1<RhsPacket>(rhs(
j,0));
200 c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*0,
j),b0,c0);
201 c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*1,
j),b0,c1);
202 c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*2,
j),b0,c2);
203 c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*3,
j),b0,c3);
214 ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
215 c1 = pset1<ResPacket>(ResScalar(0)),
216 c2 = pset1<ResPacket>(ResScalar(0));
220 RhsPacket b0 = pset1<RhsPacket>(rhs(
j,0));
221 c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*0,
j),b0,c0);
222 c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*1,
j),b0,c1);
223 c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*2,
j),b0,c2);
233 ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
234 c1 = pset1<ResPacket>(ResScalar(0));
238 RhsPacket b0 = pset1<RhsPacket>(rhs(
j,0));
239 c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*0,
j),b0,c0);
240 c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+LhsPacketSize*1,
j),b0,c1);
248 ResPacket c0 = pset1<ResPacket>(ResScalar(0));
251 RhsPacket b0 = pset1<RhsPacket>(rhs(
j,0));
252 c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+0,
j),b0,c0);
257 if(HasHalf &&
i<n_half)
259 ResPacketHalf c0 = pset1<ResPacketHalf>(ResScalar(0));
262 RhsPacketHalf b0 = pset1<RhsPacketHalf>(rhs(
j,0));
263 c0 = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(
i+0,
j),b0,c0);
265 pstoreu(
res+
i+ResPacketSizeHalf*0,
pmadd(c0,palpha_half,ploadu<ResPacketHalf>(
res+
i+ResPacketSizeHalf*0)));
266 i+=ResPacketSizeHalf;
268 if(HasQuarter &&
i<n_quarter)
270 ResPacketQuarter c0 = pset1<ResPacketQuarter>(ResScalar(0));
273 RhsPacketQuarter b0 = pset1<RhsPacketQuarter>(rhs(
j,0));
274 c0 = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(
i+0,
j),b0,c0);
276 pstoreu(
res+
i+ResPacketSizeQuarter*0,
pmadd(c0,palpha_quarter,ploadu<ResPacketQuarter>(
res+
i+ResPacketSizeQuarter*0)));
277 i+=ResPacketSizeQuarter;
283 c0 += cj.pmul(lhs(
i,
j), rhs(
j,0));
299 template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
300 struct general_matrix_vector_product<
Index,LhsScalar,LhsMapper,
RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
302 typedef gemv_traits<LhsScalar,RhsScalar> Traits;
303 typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
304 typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;
306 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
308 typedef typename Traits::LhsPacket LhsPacket;
309 typedef typename Traits::RhsPacket RhsPacket;
310 typedef typename Traits::ResPacket ResPacket;
312 typedef typename HalfTraits::LhsPacket LhsPacketHalf;
313 typedef typename HalfTraits::RhsPacket RhsPacketHalf;
314 typedef typename HalfTraits::ResPacket ResPacketHalf;
316 typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
317 typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
318 typedef typename QuarterTraits::ResPacket ResPacketQuarter;
322 const LhsMapper& lhs,
323 const RhsMapper& rhs,
328 template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
329 EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
331 const LhsMapper& alhs,
332 const RhsMapper& rhs,
341 conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
342 conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
343 conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
344 conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;
348 const Index n8 = lhs.stride()*
sizeof(LhsScalar)>32000 ? 0 :
rows-7;
354 ResPacketSize = Traits::ResPacketSize,
355 ResPacketSizeHalf = HalfTraits::ResPacketSize,
356 ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
357 LhsPacketSize = Traits::LhsPacketSize,
358 LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
359 LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
360 HasHalf = (int)ResPacketSizeHalf < (
int)ResPacketSize,
361 HasQuarter = (int)ResPacketSizeQuarter < (
int)ResPacketSizeHalf
367 ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
368 c1 = pset1<ResPacket>(ResScalar(0)),
369 c2 = pset1<ResPacket>(ResScalar(0)),
370 c3 = pset1<ResPacket>(ResScalar(0)),
371 c4 = pset1<ResPacket>(ResScalar(0)),
372 c5 = pset1<ResPacket>(ResScalar(0)),
373 c6 = pset1<ResPacket>(ResScalar(0)),
374 c7 = pset1<ResPacket>(ResScalar(0));
377 for(;
j+LhsPacketSize<=
cols;
j+=LhsPacketSize)
379 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(
j,0);
381 c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+0,
j),b0,c0);
382 c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+1,
j),b0,c1);
383 c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+2,
j),b0,c2);
384 c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+3,
j),b0,c3);
385 c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+4,
j),b0,c4);
386 c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+5,
j),b0,c5);
387 c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+6,
j),b0,c6);
388 c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+7,
j),b0,c7);
390 ResScalar cc0 =
predux(c0);
391 ResScalar cc1 =
predux(c1);
392 ResScalar cc2 =
predux(c2);
393 ResScalar cc3 =
predux(c3);
394 ResScalar cc4 =
predux(c4);
395 ResScalar cc5 =
predux(c5);
396 ResScalar cc6 =
predux(c6);
397 ResScalar cc7 =
predux(c7);
400 RhsScalar b0 = rhs(
j,0);
402 cc0 += cj.pmul(lhs(
i+0,
j), b0);
403 cc1 += cj.pmul(lhs(
i+1,
j), b0);
404 cc2 += cj.pmul(lhs(
i+2,
j), b0);
405 cc3 += cj.pmul(lhs(
i+3,
j), b0);
406 cc4 += cj.pmul(lhs(
i+4,
j), b0);
407 cc5 += cj.pmul(lhs(
i+5,
j), b0);
408 cc6 += cj.pmul(lhs(
i+6,
j), b0);
409 cc7 += cj.pmul(lhs(
i+7,
j), b0);
411 res[(
i+0)*resIncr] += alpha*cc0;
412 res[(
i+1)*resIncr] += alpha*cc1;
413 res[(
i+2)*resIncr] += alpha*cc2;
414 res[(
i+3)*resIncr] += alpha*cc3;
415 res[(
i+4)*resIncr] += alpha*cc4;
416 res[(
i+5)*resIncr] += alpha*cc5;
417 res[(
i+6)*resIncr] += alpha*cc6;
418 res[(
i+7)*resIncr] += alpha*cc7;
422 ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
423 c1 = pset1<ResPacket>(ResScalar(0)),
424 c2 = pset1<ResPacket>(ResScalar(0)),
425 c3 = pset1<ResPacket>(ResScalar(0));
428 for(;
j+LhsPacketSize<=
cols;
j+=LhsPacketSize)
430 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(
j,0);
432 c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+0,
j),b0,c0);
433 c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+1,
j),b0,c1);
434 c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+2,
j),b0,c2);
435 c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+3,
j),b0,c3);
437 ResScalar cc0 =
predux(c0);
438 ResScalar cc1 =
predux(c1);
439 ResScalar cc2 =
predux(c2);
440 ResScalar cc3 =
predux(c3);
443 RhsScalar b0 = rhs(
j,0);
445 cc0 += cj.pmul(lhs(
i+0,
j), b0);
446 cc1 += cj.pmul(lhs(
i+1,
j), b0);
447 cc2 += cj.pmul(lhs(
i+2,
j), b0);
448 cc3 += cj.pmul(lhs(
i+3,
j), b0);
450 res[(
i+0)*resIncr] += alpha*cc0;
451 res[(
i+1)*resIncr] += alpha*cc1;
452 res[(
i+2)*resIncr] += alpha*cc2;
453 res[(
i+3)*resIncr] += alpha*cc3;
457 ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
458 c1 = pset1<ResPacket>(ResScalar(0));
461 for(;
j+LhsPacketSize<=
cols;
j+=LhsPacketSize)
463 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(
j,0);
465 c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+0,
j),b0,c0);
466 c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i+1,
j),b0,c1);
468 ResScalar cc0 =
predux(c0);
469 ResScalar cc1 =
predux(c1);
472 RhsScalar b0 = rhs(
j,0);
474 cc0 += cj.pmul(lhs(
i+0,
j), b0);
475 cc1 += cj.pmul(lhs(
i+1,
j), b0);
477 res[(
i+0)*resIncr] += alpha*cc0;
478 res[(
i+1)*resIncr] += alpha*cc1;
482 ResPacket c0 = pset1<ResPacket>(ResScalar(0));
483 ResPacketHalf c0_h = pset1<ResPacketHalf>(ResScalar(0));
484 ResPacketQuarter c0_q = pset1<ResPacketQuarter>(ResScalar(0));
486 for(;
j+LhsPacketSize<=
cols;
j+=LhsPacketSize)
488 RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(
j,0);
489 c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(
i,
j),b0,c0);
491 ResScalar cc0 =
predux(c0);
493 for(;
j+LhsPacketSizeHalf<=
cols;
j+=LhsPacketSizeHalf)
495 RhsPacketHalf b0 = rhs.template load<RhsPacketHalf,Unaligned>(
j,0);
496 c0_h = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(
i,
j),b0,c0_h);
501 for(;
j+LhsPacketSizeQuarter<=
cols;
j+=LhsPacketSizeQuarter)
503 RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter,Unaligned>(
j,0);
504 c0_q = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(
i,
j),b0,c0_q);
510 cc0 += cj.pmul(lhs(
i,
j), rhs(
j,0));
512 res[
i*resIncr] += alpha*cc0;
#define PACKET_DECL_COND_POSTFIX(postfix, name, packet_size)
#define eigen_internal_assert(x)
#define EIGEN_UNUSED_VARIABLE(var)
#define EIGEN_DEVICE_FUNC
#define EIGEN_DONT_INLINE
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
unpacket_traits< Packet >::type predux(const Packet &a)
Packet4f pmadd(const Packet4f &a, const Packet4f &b, const Packet4f &c)
void pstoreu(Scalar *to, const Packet &from)
EIGEN_ALWAYS_INLINE T mini(const T &x, const T &y)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.