TriangularSolverMatrix.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5 // Modifications Copyright (C) 2022 Intel Corporation
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 #ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H
12 #define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
13 
14 #include "../InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
// Unblocked triangular solve kernel for the "triangular matrix on the left,
// several right-hand sides" case. Operates directly on non-packed storage.
//
// kernel() arguments:
//   size        - number of elimination steps (order of the triangular block)
//   otherSize   - number of columns of the right-hand side that are processed
//   _tri        - pointer to the triangular coefficients, addressed with triStride
//   _other      - pointer to the right-hand side, updated in place; addressed
//                 through a column-major mapper built from otherStride/otherIncr
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride, bool Specialized>
struct trsmKernelL {
  static void kernel(Index size, Index otherSize,
                     const Scalar* _tri, Index triStride,
                     Scalar* _other, Index otherIncr, Index otherStride);
};
29 
// Unblocked triangular solve kernel for the "triangular matrix on the right,
// several left-hand sides" case. Operates directly on non-packed storage.
//
// kernel() arguments:
//   size        - number of elimination steps (order of the triangular block)
//   otherSize   - number of rows of the other matrix that are processed
//   _tri        - pointer to the triangular coefficients, addressed with triStride
//   _other      - pointer to the other matrix, updated in place; addressed
//                 through a column-major mapper built from otherStride/otherIncr
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride, bool Specialized>
struct trsmKernelR {
  static void kernel(Index size, Index otherSize,
                     const Scalar* _tri, Index triStride,
                     Scalar* _other, Index otherIncr, Index otherStride);
};
39 
40 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride, bool Specialized>
41 EIGEN_STRONG_INLINE void trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride, Specialized>::kernel(
42  Index size, Index otherSize,
43  const Scalar* _tri, Index triStride,
44  Scalar* _other, Index otherIncr, Index otherStride)
45  {
46  typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
47  typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
48  TriMapper tri(_tri, triStride);
49  OtherMapper other(_other, otherStride, otherIncr);
50 
51  enum { IsLower = (Mode&Lower) == Lower };
52  conj_if<Conjugate> conj;
53 
54  // tr solve
55  for (Index k=0; k<size; ++k)
56  {
57  // TODO write a small kernel handling this (can be shared with trsv)
58  Index i = IsLower ? k : -k-1;
59  Index rs = size - k - 1; // remaining size
60  Index s = TriStorageOrder==RowMajor ? (IsLower ? 0 : i+1)
61  : IsLower ? i+1 : i-rs;
62 
63  Scalar a = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(tri(i,i));
64  for (Index j=0; j<otherSize; ++j)
65  {
66  if (TriStorageOrder==RowMajor)
67  {
68  Scalar b(0);
69  const Scalar* l = &tri(i,s);
70  typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
71  for (Index i3=0; i3<k; ++i3)
72  b += conj(l[i3]) * r(i3);
73 
74  other(i,j) = (other(i,j) - b)*a;
75  }
76  else
77  {
78  Scalar& otherij = other(i,j);
79  otherij *= a;
80  Scalar b = otherij;
81  typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
82  typename TriMapper::LinearMapper l = tri.getLinearMapper(s,i);
83  for (Index i3=0;i3<rs;++i3)
84  r(i3) -= b * conj(l(i3));
85  }
86  }
87  }
88  }
89 
90 
91 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride, bool Specialized>
92 EIGEN_STRONG_INLINE void trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride, Specialized>::kernel(
93  Index size, Index otherSize,
94  const Scalar* _tri, Index triStride,
95  Scalar* _other, Index otherIncr, Index otherStride)
96 {
97  typedef typename NumTraits<Scalar>::Real RealScalar;
98  typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
99  typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
100  LhsMapper lhs(_other, otherStride, otherIncr);
101  RhsMapper rhs(_tri, triStride);
102 
103  enum {
104  RhsStorageOrder = TriStorageOrder,
105  IsLower = (Mode&Lower) == Lower
106  };
107  conj_if<Conjugate> conj;
108 
109  for (Index k=0; k<size; ++k)
110  {
111  Index j = IsLower ? size-k-1 : k;
112 
113  typename LhsMapper::LinearMapper r = lhs.getLinearMapper(0,j);
114  for (Index k3=0; k3<k; ++k3)
115  {
116  Scalar b = conj(rhs(IsLower ? j+1+k3 : k3,j));
117  typename LhsMapper::LinearMapper a = lhs.getLinearMapper(0,IsLower ? j+1+k3 : k3);
118  for (Index i=0; i<otherSize; ++i)
119  r(i) -= a(i) * b;
120  }
121  if((Mode & UnitDiag)==0)
122  {
123  Scalar inv_rjj = RealScalar(1)/conj(rhs(j,j));
124  for (Index i=0; i<otherSize; ++i)
125  r(i) *= inv_rjj;
126  }
127  }
128 }
129 
130 
131 // if the rhs is row major, let's transpose the product
132 template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
133 struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,RowMajor,OtherInnerStride>
134 {
135  static void run(
136  Index size, Index cols,
137  const Scalar* tri, Index triStride,
138  Scalar* _other, Index otherIncr, Index otherStride,
139  level3_blocking<Scalar,Scalar>& blocking)
140  {
141  triangular_solve_matrix<
142  Scalar, Index, Side==OnTheLeft?OnTheRight:OnTheLeft,
143  (Mode&UnitDiag) | ((Mode&Upper) ? Lower : Upper),
144  NumTraits<Scalar>::IsComplex && Conjugate,
145  TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor, OtherInnerStride>
146  ::run(size, cols, tri, triStride, _other, otherIncr, otherStride, blocking);
147  }
148 };
149 
150 /* Optimized triangular solver with multiple right hand side and the triangular matrix on the left
151  */
152 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
153 struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>
154 {
155  static EIGEN_DONT_INLINE void run(
156  Index size, Index otherSize,
157  const Scalar* _tri, Index triStride,
158  Scalar* _other, Index otherIncr, Index otherStride,
159  level3_blocking<Scalar,Scalar>& blocking);
160 };
161 
162 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
163 EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
164  Index size, Index otherSize,
165  const Scalar* _tri, Index triStride,
166  Scalar* _other, Index otherIncr, Index otherStride,
167  level3_blocking<Scalar,Scalar>& blocking)
168  {
169  Index cols = otherSize;
170 
171  std::ptrdiff_t l1, l2, l3;
172  manage_caching_sizes(GetAction, &l1, &l2, &l3);
173 
174 #if defined(EIGEN_VECTORIZE_AVX512) && EIGEN_USE_AVX512_TRSM_L_KERNELS && EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS
175  EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
176  (std::is_same<Scalar,float>::value ||
177  std::is_same<Scalar,double>::value)) ) {
178  // Very rough cutoffs to determine when to call trsm w/o packing
179  // For small problem sizes trsmKernel compiled with clang is generally faster.
180  // TODO: Investigate better heuristics for cutoffs.
181  double L2Cap = 0.5; // 50% of L2 size
182  if (size < avx512_trsm_cutoff<Scalar>(l2, cols, L2Cap)) {
183  trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, 1, /*Specialized=*/true>::kernel(
184  size, cols, _tri, triStride, _other, 1, otherStride);
185  return;
186  }
187  }
188 #endif
189 
190  typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
191  typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
192  TriMapper tri(_tri, triStride);
193  OtherMapper other(_other, otherStride, otherIncr);
194 
195  typedef gebp_traits<Scalar,Scalar> Traits;
196 
197  enum {
198  SmallPanelWidth = plain_enum_max(Traits::mr, Traits::nr),
199  IsLower = (Mode&Lower) == Lower
200  };
201 
202  Index kc = blocking.kc(); // cache block size along the K direction
203  Index mc = (std::min)(size,blocking.mc()); // cache block size along the M direction
204 
205  std::size_t sizeA = kc*mc;
206  std::size_t sizeB = kc*cols;
207 
208  ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
209  ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
210 
211  gebp_kernel<Scalar, Scalar, Index, OtherMapper, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel;
212  gemm_pack_lhs<Scalar, Index, TriMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, TriStorageOrder> pack_lhs;
213  gemm_pack_rhs<Scalar, Index, OtherMapper, Traits::nr, ColMajor, false, true> pack_rhs;
214 
215  // the goal here is to subdivise the Rhs panels such that we keep some cache
216  // coherence when accessing the rhs elements
217  Index subcols = cols>0 ? l2/(4 * sizeof(Scalar) * std::max<Index>(otherStride,size)) : 0;
218  subcols = std::max<Index>((subcols/Traits::nr)*Traits::nr, Traits::nr);
219 
220  for(Index k2=IsLower ? 0 : size;
221  IsLower ? k2<size : k2>0;
222  IsLower ? k2+=kc : k2-=kc)
223  {
224  const Index actual_kc = (std::min)(IsLower ? size-k2 : k2, kc);
225 
226  // We have selected and packed a big horizontal panel R1 of rhs. Let B be the packed copy of this panel,
227  // and R2 the remaining part of rhs. The corresponding vertical panel of lhs is split into
228  // A11 (the triangular part) and A21 the remaining rectangular part.
229  // Then the high level algorithm is:
230  // - B = R1 => general block copy (done during the next step)
231  // - R1 = A11^-1 B => tricky part
232  // - update B from the new R1 => actually this has to be performed continuously during the above step
233  // - R2 -= A21 * B => GEPP
234 
235  // The tricky part: compute R1 = A11^-1 B while updating B from R1
236  // The idea is to split A11 into multiple small vertical panels.
237  // Each panel can be split into a small triangular part T1k which is processed without optimization,
238  // and the remaining small part T2k which is processed using gebp with appropriate block strides
239  for(Index j2=0; j2<cols; j2+=subcols)
240  {
241  Index actual_cols = (std::min)(cols-j2,subcols);
242  // for each small vertical panels [T1k^T, T2k^T]^T of lhs
243  for (Index k1=0; k1<actual_kc; k1+=SmallPanelWidth)
244  {
245  Index actualPanelWidth = std::min<Index>(actual_kc-k1, SmallPanelWidth);
246  // tr solve
247  {
248  Index i = IsLower ? k2+k1 : k2-k1;
249 #if defined(EIGEN_VECTORIZE_AVX512) && EIGEN_USE_AVX512_TRSM_L_KERNELS
250  EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
251  (std::is_same<Scalar,float>::value ||
252  std::is_same<Scalar,double>::value)) ) {
253  i = IsLower ? k2 + k1: k2 - k1 - actualPanelWidth;
254  }
255 #endif
256  trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride, /*Specialized=*/true>::kernel(
257  actualPanelWidth, actual_cols,
258  _tri + i + (i)*triStride, triStride,
259  _other + i*OtherInnerStride + j2*otherStride, otherIncr, otherStride);
260  }
261 
262  Index lengthTarget = actual_kc-k1-actualPanelWidth;
263  Index startBlock = IsLower ? k2+k1 : k2-k1-actualPanelWidth;
264  Index blockBOffset = IsLower ? k1 : lengthTarget;
265 
266  // update the respective rows of B from other
267  pack_rhs(blockB+actual_kc*j2, other.getSubMapper(startBlock,j2), actualPanelWidth, actual_cols, actual_kc, blockBOffset);
268 
269  // GEBP
270  if (lengthTarget>0)
271  {
272  Index startTarget = IsLower ? k2+k1+actualPanelWidth : k2-actual_kc;
273 
274  pack_lhs(blockA, tri.getSubMapper(startTarget,startBlock), actualPanelWidth, lengthTarget);
275 
276  gebp_kernel(other.getSubMapper(startTarget,j2), blockA, blockB+actual_kc*j2, lengthTarget, actualPanelWidth, actual_cols, Scalar(-1),
277  actualPanelWidth, actual_kc, 0, blockBOffset);
278  }
279  }
280  }
281 
282  // R2 -= A21 * B => GEPP
283  {
284  Index start = IsLower ? k2+kc : 0;
285  Index end = IsLower ? size : k2-kc;
286  for(Index i2=start; i2<end; i2+=mc)
287  {
288  const Index actual_mc = (std::min)(mc,end-i2);
289  if (actual_mc>0)
290  {
291  pack_lhs(blockA, tri.getSubMapper(i2, IsLower ? k2 : k2-kc), actual_kc, actual_mc);
292 
293  gebp_kernel(other.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, Scalar(-1), -1, -1, 0, 0);
294  }
295  }
296  }
297  }
298  }
299 
300 /* Optimized triangular solver with multiple left hand sides and the triangular matrix on the right
301  */
302 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
303 struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>
304 {
305  static EIGEN_DONT_INLINE void run(
306  Index size, Index otherSize,
307  const Scalar* _tri, Index triStride,
308  Scalar* _other, Index otherIncr, Index otherStride,
309  level3_blocking<Scalar,Scalar>& blocking);
310 };
311 
312 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
313 EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
314  Index size, Index otherSize,
315  const Scalar* _tri, Index triStride,
316  Scalar* _other, Index otherIncr, Index otherStride,
317  level3_blocking<Scalar,Scalar>& blocking)
318  {
319  Index rows = otherSize;
320 
321 #if defined(EIGEN_VECTORIZE_AVX512) && EIGEN_USE_AVX512_TRSM_R_KERNELS && EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
322  EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
323  (std::is_same<Scalar,float>::value ||
324  std::is_same<Scalar,double>::value)) ) {
325  // TODO: Investigate better heuristics for cutoffs.
326  std::ptrdiff_t l1, l2, l3;
327  manage_caching_sizes(GetAction, &l1, &l2, &l3);
328  double L2Cap = 0.5; // 50% of L2 size
329  if (size < avx512_trsm_cutoff<Scalar>(l2, rows, L2Cap)) {
330  trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride, /*Specialized=*/true>::
331  kernel(size, rows, _tri, triStride, _other, 1, otherStride);
332  return;
333  }
334  }
335 #endif
336 
337  typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
338  typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
339  LhsMapper lhs(_other, otherStride, otherIncr);
340  RhsMapper rhs(_tri, triStride);
341 
342  typedef gebp_traits<Scalar,Scalar> Traits;
343  enum {
344  RhsStorageOrder = TriStorageOrder,
345  SmallPanelWidth = plain_enum_max(Traits::mr, Traits::nr),
346  IsLower = (Mode&Lower) == Lower
347  };
348 
349  Index kc = blocking.kc(); // cache block size along the K direction
350  Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
351 
352  std::size_t sizeA = kc*mc;
353  std::size_t sizeB = kc*size;
354 
355  ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
356  ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
357 
358  gebp_kernel<Scalar, Scalar, Index, LhsMapper, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel;
359  gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
360  gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder,false,true> pack_rhs_panel;
361  gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor, false, true> pack_lhs_panel;
362 
363  for(Index k2=IsLower ? size : 0;
364  IsLower ? k2>0 : k2<size;
365  IsLower ? k2-=kc : k2+=kc)
366  {
367  const Index actual_kc = (std::min)(IsLower ? k2 : size-k2, kc);
368  Index actual_k2 = IsLower ? k2-actual_kc : k2 ;
369 
370  Index startPanel = IsLower ? 0 : k2+actual_kc;
371  Index rs = IsLower ? actual_k2 : size - actual_k2 - actual_kc;
372  Scalar* geb = blockB+actual_kc*actual_kc;
373 
374  if (rs>0) pack_rhs(geb, rhs.getSubMapper(actual_k2,startPanel), actual_kc, rs);
375 
376  // triangular packing (we only pack the panels off the diagonal,
377  // neglecting the blocks overlapping the diagonal
378  {
379  for (Index j2=0; j2<actual_kc; j2+=SmallPanelWidth)
380  {
381  Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth);
382  Index actual_j2 = actual_k2 + j2;
383  Index panelOffset = IsLower ? j2+actualPanelWidth : 0;
384  Index panelLength = IsLower ? actual_kc-j2-actualPanelWidth : j2;
385 
386  if (panelLength>0)
387  pack_rhs_panel(blockB+j2*actual_kc,
388  rhs.getSubMapper(actual_k2+panelOffset, actual_j2),
389  panelLength, actualPanelWidth,
390  actual_kc, panelOffset);
391  }
392  }
393 
394  for(Index i2=0; i2<rows; i2+=mc)
395  {
396  const Index actual_mc = (std::min)(mc,rows-i2);
397 
398  // triangular solver kernel
399  {
400  // for each small block of the diagonal (=> vertical panels of rhs)
401  for (Index j2 = IsLower
402  ? (actual_kc - ((actual_kc%SmallPanelWidth) ? Index(actual_kc%SmallPanelWidth)
403  : Index(SmallPanelWidth)))
404  : 0;
405  IsLower ? j2>=0 : j2<actual_kc;
406  IsLower ? j2-=SmallPanelWidth : j2+=SmallPanelWidth)
407  {
408  Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth);
409  Index absolute_j2 = actual_k2 + j2;
410  Index panelOffset = IsLower ? j2+actualPanelWidth : 0;
411  Index panelLength = IsLower ? actual_kc - j2 - actualPanelWidth : j2;
412 
413  // GEBP
414  if(panelLength>0)
415  {
416  gebp_kernel(lhs.getSubMapper(i2,absolute_j2),
417  blockA, blockB+j2*actual_kc,
418  actual_mc, panelLength, actualPanelWidth,
419  Scalar(-1),
420  actual_kc, actual_kc, // strides
421  panelOffset, panelOffset); // offsets
422  }
423 
424  {
425  // unblocked triangular solve
426  trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride, /*Specialized=*/true>::
427  kernel(actualPanelWidth, actual_mc,
428  _tri + absolute_j2 + absolute_j2*triStride, triStride,
429  _other + i2*OtherInnerStride + absolute_j2*otherStride, otherIncr, otherStride);
430  }
431  // pack the just computed part of lhs to A
432  pack_lhs_panel(blockA, lhs.getSubMapper(i2,absolute_j2),
433  actualPanelWidth, actual_mc,
434  actual_kc, j2);
435  }
436  }
437 
438  if (rs>0)
439  gebp_kernel(lhs.getSubMapper(i2, startPanel), blockA, geb,
440  actual_mc, actual_kc, rs, Scalar(-1),
441  -1, -1, 0, 0);
442  }
443  }
444  }
445 } // end namespace internal
446 
447 } // end namespace Eigen
448 
449 #endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_H
Array< int, 3, 1 > b
#define EIGEN_DONT_INLINE
Definition: Macros.h:844
#define EIGEN_IF_CONSTEXPR(X)
Definition: Macros.h:1298
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
Definition: Memory.h:847
Tridiagonalization< MatrixXf > tri
static const lastp1_t end
@ UnitDiag
Definition: Constants.h:215
@ Lower
Definition: Constants.h:211
@ Upper
Definition: Constants.h:213
@ ColMajor
Definition: Constants.h:321
@ RowMajor
Definition: Constants.h:323
@ OnTheLeft
Definition: Constants.h:334
@ OnTheRight
Definition: Constants.h:336
bfloat16() min(const bfloat16 &a, const bfloat16 &b)
Definition: BFloat16.h:684
constexpr int plain_enum_max(A a, B b)
Definition: Meta.h:524
void manage_caching_sizes(Action action, std::ptrdiff_t *l1, std::ptrdiff_t *l2, std::ptrdiff_t *l3)
: InteropHeaders
Definition: Core:139
@ GetAction
Definition: Constants.h:508
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_conjugate_op< typename Derived::Scalar >, const Derived > conj(const Eigen::ArrayBase< Derived > &x)
std::ptrdiff_t j