GeneralMatrixVector.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2016 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
#define EIGEN_GENERAL_MATRIX_VECTOR_H

#include "../InternalHeaderCheck.h"

namespace Eigen {

namespace internal {

enum GEMVPacketSizeType {
  GEMVPacketFull = 0,
  GEMVPacketHalf,
  GEMVPacketQuarter
};

template <int N, typename T1, typename T2, typename T3>
struct gemv_packet_cond { typedef T3 type; };

template <typename T1, typename T2, typename T3>
struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> { typedef T1 type; };

template <typename T1, typename T2, typename T3>
struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> { typedef T2 type; };
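
// For example (illustrative, assuming AVX, where packet_traits<float>::type is
// Packet8f and its half is Packet4f):
//   gemv_packet_cond<GEMVPacketFull,    Packet8f, Packet4f, Packet4f>::type is Packet8f
//   gemv_packet_cond<GEMVPacketHalf,    Packet8f, Packet4f, Packet4f>::type is Packet4f
//   gemv_packet_cond<GEMVPacketQuarter, Packet8f, Packet4f, Packet4f>::type is Packet4f
// When no genuinely smaller packet type exists, the quarter type simply
// aliases the half type.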

template<typename LhsScalar, typename RhsScalar, int PacketSize_=GEMVPacketFull>
class gemv_traits
{
  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;

#define PACKET_DECL_COND_POSTFIX(postfix, name, packet_size) \
  typedef typename gemv_packet_cond<packet_size, \
                                    typename packet_traits<name ## Scalar>::type, \
                                    typename packet_traits<name ## Scalar>::half, \
                                    typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \
  name ## Packet ## postfix

  PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
  PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
  PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
#undef PACKET_DECL_COND_POSTFIX

public:
  enum {
    Vectorizable = unpacket_traits<LhsPacket_>::vectorizable &&
                   unpacket_traits<RhsPacket_>::vectorizable &&
                   int(unpacket_traits<LhsPacket_>::size) == int(unpacket_traits<RhsPacket_>::size),
    LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
    RhsPacketSize = Vectorizable ? unpacket_traits<RhsPacket_>::size : 1,
    ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1
  };

  typedef std::conditional_t<Vectorizable,LhsPacket_,LhsScalar> LhsPacket;
  typedef std::conditional_t<Vectorizable,RhsPacket_,RhsScalar> RhsPacket;
  typedef std::conditional_t<Vectorizable,ResPacket_,ResScalar> ResPacket;
};
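
// For instance (illustrative, assuming AVX): gemv_traits<float,float> is
// vectorizable with LhsPacket == RhsPacket == ResPacket == Packet8f, whereas
// gemv_traits<double, std::complex<double> > is not vectorizable because the
// natural packet sizes differ (Packet4d holds 4 doubles, Packet2cd holds 2
// complex doubles), so the scalar types are used instead.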

/* Optimized col-major matrix * vector product:
 * This algorithm processes the matrix per vertical panels,
 * which are then processed horizontally in chunks of 8*PacketSize x 1 vertical segments.
 *
 * Mixing type logic: C += alpha * A * B
 * | A | B |alpha| comments
 * |real |cplx |cplx | no vectorization
 * |real |cplx |real | alpha is converted to a cplx when calling the run function, no vectorization
 * |cplx |real |cplx | invalid, the caller has to do tmp = A * B; C += alpha*tmp
 * |cplx |real |real | optimal case, vectorization possible via real-cplx mul
 *
 * The same reasoning applies to the transposed case.
 */
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
{
  typedef gemv_traits<LhsScalar,RhsScalar> Traits;
  typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
  typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;

  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef typename HalfTraits::LhsPacket LhsPacketHalf;
  typedef typename HalfTraits::RhsPacket RhsPacketHalf;
  typedef typename HalfTraits::ResPacket ResPacketHalf;

  typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
  typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
  typedef typename QuarterTraits::ResPacket ResPacketQuarter;

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(
    Index rows, Index cols,
    const LhsMapper& lhs,
    const RhsMapper& rhs,
    ResScalar* res, Index resIncr,
    RhsScalar alpha);
};

template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
    Index rows, Index cols,
    const LhsMapper& alhs,
    const RhsMapper& rhs,
    ResScalar* res, Index resIncr,
    RhsScalar alpha)
{
  EIGEN_UNUSED_VARIABLE(resIncr);
  eigen_internal_assert(resIncr==1);

  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC generate proper code.
  LhsMapper lhs(alhs);

  conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
  conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
  conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
  conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;

  const Index lhsStride = lhs.stride();
  // TODO: for padded aligned inputs, we could enable aligned reads
  enum { LhsAlignment = Unaligned,
         ResPacketSize = Traits::ResPacketSize,
         ResPacketSizeHalf = HalfTraits::ResPacketSize,
         ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
         LhsPacketSize = Traits::LhsPacketSize,
         HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
         HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
  };

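  // HasHalf/HasQuarter are compile-time false when the target provides no genuinely
  // smaller packet type: the Half/Quarter typedefs then alias the larger packets and
  // the corresponding tail code below is compiled away. The n* bounds below drive
  // the row-peeling cascade: the step that handles k*ResPacketSize rows at once
  // runs only while i < rows - k*ResPacketSize + 1.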
  const Index n8 = rows-8*ResPacketSize+1;
  const Index n4 = rows-4*ResPacketSize+1;
  const Index n3 = rows-3*ResPacketSize+1;
  const Index n2 = rows-2*ResPacketSize+1;
  const Index n1 = rows-1*ResPacketSize+1;
  const Index n_half = rows-1*ResPacketSizeHalf+1;
  const Index n_quarter = rows-1*ResPacketSizeQuarter+1;

  // TODO: improve the following heuristic:
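  // (All columns fit in one block when there are fewer than 128; otherwise sweep
  // 16 columns at a time, shrinking to 4 once a single column already spans
  // ~32000 bytes, so that the lhs panels being swept are more likely to stay
  // cache-resident.)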
  const Index block_cols = cols<128 ? cols : (lhsStride*sizeof(LhsScalar)<32000?16:4);
  ResPacket palpha = pset1<ResPacket>(alpha);
  ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha);
  ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha);

  for(Index j2=0; j2<cols; j2+=block_cols)
  {
    Index jend = numext::mini(j2+block_cols,cols);
    Index i=0;
    for(; i<n8; i+=ResPacketSize*8)
    {
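      // Eight independent packet accumulators: splitting the running sum breaks
      // the dependency chain between consecutive fused multiply-adds.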
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
                c1 = pset1<ResPacket>(ResScalar(0)),
                c2 = pset1<ResPacket>(ResScalar(0)),
                c3 = pset1<ResPacket>(ResScalar(0)),
                c4 = pset1<ResPacket>(ResScalar(0)),
                c5 = pset1<ResPacket>(ResScalar(0)),
                c6 = pset1<ResPacket>(ResScalar(0)),
                c7 = pset1<ResPacket>(ResScalar(0));

      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
        c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2);
        c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*3,j),b0,c3);
        c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*4,j),b0,c4);
        c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*5,j),b0,c5);
        c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*6,j),b0,c6);
        c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*7,j),b0,c7);
      }
      pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
      pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
      pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2)));
      pstoreu(res+i+ResPacketSize*3, pmadd(c3,palpha,ploadu<ResPacket>(res+i+ResPacketSize*3)));
      pstoreu(res+i+ResPacketSize*4, pmadd(c4,palpha,ploadu<ResPacket>(res+i+ResPacketSize*4)));
      pstoreu(res+i+ResPacketSize*5, pmadd(c5,palpha,ploadu<ResPacket>(res+i+ResPacketSize*5)));
      pstoreu(res+i+ResPacketSize*6, pmadd(c6,palpha,ploadu<ResPacket>(res+i+ResPacketSize*6)));
      pstoreu(res+i+ResPacketSize*7, pmadd(c7,palpha,ploadu<ResPacket>(res+i+ResPacketSize*7)));
    }
    if(i<n4)
    {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
                c1 = pset1<ResPacket>(ResScalar(0)),
                c2 = pset1<ResPacket>(ResScalar(0)),
                c3 = pset1<ResPacket>(ResScalar(0));

      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
        c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2);
        c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*3,j),b0,c3);
      }
      pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
      pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
      pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2)));
      pstoreu(res+i+ResPacketSize*3, pmadd(c3,palpha,ploadu<ResPacket>(res+i+ResPacketSize*3)));

      i+=ResPacketSize*4;
    }
    if(i<n3)
    {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
                c1 = pset1<ResPacket>(ResScalar(0)),
                c2 = pset1<ResPacket>(ResScalar(0));

      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
        c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*2,j),b0,c2);
      }
      pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
      pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
      pstoreu(res+i+ResPacketSize*2, pmadd(c2,palpha,ploadu<ResPacket>(res+i+ResPacketSize*2)));

      i+=ResPacketSize*3;
    }
    if(i<n2)
    {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
                c1 = pset1<ResPacket>(ResScalar(0));

      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*0,j),b0,c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+LhsPacketSize*1,j),b0,c1);
      }
      pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
      pstoreu(res+i+ResPacketSize*1, pmadd(c1,palpha,ploadu<ResPacket>(res+i+ResPacketSize*1)));
      i+=ResPacketSize*2;
    }
    if(i<n1)
    {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0));
      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j,0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
      }
      pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
      i+=ResPacketSize;
    }
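    // For the remaining rows, retry with half- and then quarter-size packets
    // before falling back to the scalar loop.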
    if(HasHalf && i<n_half)
    {
      ResPacketHalf c0 = pset1<ResPacketHalf>(ResScalar(0));
      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacketHalf b0 = pset1<RhsPacketHalf>(rhs(j,0));
        c0 = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i+0,j),b0,c0);
      }
      pstoreu(res+i+ResPacketSizeHalf*0, pmadd(c0,palpha_half,ploadu<ResPacketHalf>(res+i+ResPacketSizeHalf*0)));
      i+=ResPacketSizeHalf;
    }
    if(HasQuarter && i<n_quarter)
    {
      ResPacketQuarter c0 = pset1<ResPacketQuarter>(ResScalar(0));
      for(Index j=j2; j<jend; j+=1)
      {
        RhsPacketQuarter b0 = pset1<RhsPacketQuarter>(rhs(j,0));
        c0 = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i+0,j),b0,c0);
      }
      pstoreu(res+i+ResPacketSizeQuarter*0, pmadd(c0,palpha_quarter,ploadu<ResPacketQuarter>(res+i+ResPacketSizeQuarter*0)));
      i+=ResPacketSizeQuarter;
    }
    for(;i<rows;++i)
    {
      ResScalar c0(0);
      for(Index j=j2; j<jend; j+=1)
        c0 += cj.pmul(lhs(i,j), rhs(j,0));
      res[i] += alpha*c0;
    }
  }
}

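// Illustrative usage (not part of this file): with a dynamically-sized,
// column-major matrix, an expression such as
//   Eigen::MatrixXf A(m, n);
//   Eigen::VectorXf x(n), y(m);
//   y.noalias() += alpha * A * x;
// is evaluated by the ColMajor kernel above.
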
/* Optimized row-major matrix * vector product:
 * This algorithm processes 4 rows at once, which both reduces the number of
 * loads/stores of the result by a factor of 4 and shortens the instruction
 * dependency chains. Moreover, we know that all bands have the same
 * alignment pattern.
 *
 * Mixing type logic:
 *  - alpha is always a complex (or converted to a complex)
 *  - no vectorization
 */
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
{
  typedef gemv_traits<LhsScalar,RhsScalar> Traits;
  typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
  typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;

  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef typename HalfTraits::LhsPacket LhsPacketHalf;
  typedef typename HalfTraits::RhsPacket RhsPacketHalf;
  typedef typename HalfTraits::ResPacket ResPacketHalf;

  typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
  typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
  typedef typename QuarterTraits::ResPacket ResPacketQuarter;

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(
    Index rows, Index cols,
    const LhsMapper& lhs,
    const RhsMapper& rhs,
    ResScalar* res, Index resIncr,
    ResScalar alpha);
};

template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
    Index rows, Index cols,
    const LhsMapper& alhs,
    const RhsMapper& rhs,
    ResScalar* res, Index resIncr,
    ResScalar alpha)
{
  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC generate proper code.
  LhsMapper lhs(alhs);

  eigen_internal_assert(rhs.stride()==1);
  conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
  conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
  conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
  conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;

  // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
  // processing 8 rows at once might be counterproductive with respect to cache usage.
  const Index n8 = lhs.stride()*sizeof(LhsScalar)>32000 ? 0 : rows-7;
  const Index n4 = rows-3;
  const Index n2 = rows-1;

  // TODO: for padded aligned inputs, we could enable aligned reads
  enum { LhsAlignment = Unaligned,
         ResPacketSize = Traits::ResPacketSize,
         ResPacketSizeHalf = HalfTraits::ResPacketSize,
         ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
         LhsPacketSize = Traits::LhsPacketSize,
         LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
         LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
         HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
         HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
  };

  Index i=0;
  for(; i<n8; i+=8)
  {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
              c1 = pset1<ResPacket>(ResScalar(0)),
              c2 = pset1<ResPacket>(ResScalar(0)),
              c3 = pset1<ResPacket>(ResScalar(0)),
              c4 = pset1<ResPacket>(ResScalar(0)),
              c5 = pset1<ResPacket>(ResScalar(0)),
              c6 = pset1<ResPacket>(ResScalar(0)),
              c7 = pset1<ResPacket>(ResScalar(0));

    Index j=0;
    for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
    {
      RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);

      c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
      c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
      c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
      c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
      c4 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+4,j),b0,c4);
      c5 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+5,j),b0,c5);
      c6 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+6,j),b0,c6);
      c7 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+7,j),b0,c7);
    }
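    // Horizontally reduce each packet accumulator to a scalar partial dot product.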
    ResScalar cc0 = predux(c0);
    ResScalar cc1 = predux(c1);
    ResScalar cc2 = predux(c2);
    ResScalar cc3 = predux(c3);
    ResScalar cc4 = predux(c4);
    ResScalar cc5 = predux(c5);
    ResScalar cc6 = predux(c6);
    ResScalar cc7 = predux(c7);
    for(; j<cols; ++j)
    {
      RhsScalar b0 = rhs(j,0);

      cc0 += cj.pmul(lhs(i+0,j), b0);
      cc1 += cj.pmul(lhs(i+1,j), b0);
      cc2 += cj.pmul(lhs(i+2,j), b0);
      cc3 += cj.pmul(lhs(i+3,j), b0);
      cc4 += cj.pmul(lhs(i+4,j), b0);
      cc5 += cj.pmul(lhs(i+5,j), b0);
      cc6 += cj.pmul(lhs(i+6,j), b0);
      cc7 += cj.pmul(lhs(i+7,j), b0);
    }
    res[(i+0)*resIncr] += alpha*cc0;
    res[(i+1)*resIncr] += alpha*cc1;
    res[(i+2)*resIncr] += alpha*cc2;
    res[(i+3)*resIncr] += alpha*cc3;
    res[(i+4)*resIncr] += alpha*cc4;
    res[(i+5)*resIncr] += alpha*cc5;
    res[(i+6)*resIncr] += alpha*cc6;
    res[(i+7)*resIncr] += alpha*cc7;
  }
  for(; i<n4; i+=4)
  {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
              c1 = pset1<ResPacket>(ResScalar(0)),
              c2 = pset1<ResPacket>(ResScalar(0)),
              c3 = pset1<ResPacket>(ResScalar(0));

    Index j=0;
    for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
    {
      RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);

      c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
      c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
      c2 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+2,j),b0,c2);
      c3 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+3,j),b0,c3);
    }
    ResScalar cc0 = predux(c0);
    ResScalar cc1 = predux(c1);
    ResScalar cc2 = predux(c2);
    ResScalar cc3 = predux(c3);
    for(; j<cols; ++j)
    {
      RhsScalar b0 = rhs(j,0);

      cc0 += cj.pmul(lhs(i+0,j), b0);
      cc1 += cj.pmul(lhs(i+1,j), b0);
      cc2 += cj.pmul(lhs(i+2,j), b0);
      cc3 += cj.pmul(lhs(i+3,j), b0);
    }
    res[(i+0)*resIncr] += alpha*cc0;
    res[(i+1)*resIncr] += alpha*cc1;
    res[(i+2)*resIncr] += alpha*cc2;
    res[(i+3)*resIncr] += alpha*cc3;
  }
  for(; i<n2; i+=2)
  {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0)),
              c1 = pset1<ResPacket>(ResScalar(0));

    Index j=0;
    for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
    {
      RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j,0);

      c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+0,j),b0,c0);
      c1 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i+1,j),b0,c1);
    }
    ResScalar cc0 = predux(c0);
    ResScalar cc1 = predux(c1);
    for(; j<cols; ++j)
    {
      RhsScalar b0 = rhs(j,0);

      cc0 += cj.pmul(lhs(i+0,j), b0);
      cc1 += cj.pmul(lhs(i+1,j), b0);
    }
    res[(i+0)*resIncr] += alpha*cc0;
    res[(i+1)*resIncr] += alpha*cc1;
  }
  for(; i<rows; ++i)
  {
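    // Last rows, one at a time: the dot product over the columns uses full
    // packets, then half and quarter packets, then scalars for the tail.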
    ResPacket c0 = pset1<ResPacket>(ResScalar(0));
    ResPacketHalf c0_h = pset1<ResPacketHalf>(ResScalar(0));
    ResPacketQuarter c0_q = pset1<ResPacketQuarter>(ResScalar(0));
    Index j=0;
    for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
    {
      RhsPacket b0 = rhs.template load<RhsPacket,Unaligned>(j,0);
      c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i,j),b0,c0);
    }
    ResScalar cc0 = predux(c0);
    if (HasHalf) {
      for(; j+LhsPacketSizeHalf<=cols; j+=LhsPacketSizeHalf)
      {
        RhsPacketHalf b0 = rhs.template load<RhsPacketHalf,Unaligned>(j,0);
        c0_h = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i,j),b0,c0_h);
      }
      cc0 += predux(c0_h);
    }
    if (HasQuarter) {
      for(; j+LhsPacketSizeQuarter<=cols; j+=LhsPacketSizeQuarter)
      {
        RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter,Unaligned>(j,0);
        c0_q = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i,j),b0,c0_q);
      }
      cc0 += predux(c0_q);
    }
    for(; j<cols; ++j)
    {
      cc0 += cj.pmul(lhs(i,j), rhs(j,0));
    }
    res[i*resIncr] += alpha*cc0;
  }
}

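// Illustrative usage (not part of this file): with a row-major matrix,
//   Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> A(m, n);
//   Eigen::VectorXf x(n), y(m);
//   y.noalias() += alpha * A * x;
// the product is evaluated by the RowMajor kernel above, each coefficient of y
// being the dot product of one row of A with x.
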
} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_GENERAL_MATRIX_VECTOR_H