11 #ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H
12 #define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
14 #include "../InternalHeaderCheck.h"
// Template header and trailing parameter list of a kernel declaration.
// NOTE(review): the struct/function header lines are not present in this listing.
// Presumably this declares trsmKernelL, whose out-of-line `kernel` definition
// further down carries the same seven template parameters — confirm against the full file.
//
// Visible parameters:
//   _tri, triStride                 - read-only Scalar array plus its stride.
//   _other, otherIncr, otherStride  - mutable Scalar array plus its increment and stride.
20 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide,
bool Specialized>
26 const Scalar* _tri,
Index triStride,
27 Scalar* _other,
Index otherIncr,
Index otherStride);
// Template header and trailing parameter list of a second kernel declaration.
// NOTE(review): the struct/function header lines are not present in this listing.
// Presumably this declares trsmKernelR, whose out-of-line `kernel` definition
// appears later with the same seven template parameters — confirm against the full file.
//
// Visible parameters:
//   _tri, triStride                 - read-only Scalar array plus its stride.
//   _other, otherIncr, otherStride  - mutable Scalar array plus its increment and stride.
30 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide,
bool Specialized>
36 const Scalar* _tri,
Index triStride,
37 Scalar* _other,
Index otherIncr,
Index otherStride);
// Out-of-line definition of trsmKernelL<...>::kernel.
// NOTE(review): this listing is lossy. The leading parameters, the braces, the
// enclosing loops and the declarations of IsLower, k, rs, s, j, a and b are not
// shown, so the notes below describe only the statements that are visible.
//
// Visible parameters:
//   _tri, triStride                 - read-only array wrapped by the `tri` mapper.
//   _other, otherIncr, otherStride  - array wrapped by the `other` mapper; it is
//                                     written through that mapper below.
40 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide,
bool Specialized>
41 EIGEN_STRONG_INLINE
void trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride, Specialized>::kernel(
43 const Scalar* _tri,
Index triStride,
44 Scalar* _other,
Index otherIncr,
Index otherStride)
// Wrap the raw pointers in BLAS data mappers so that entries can be addressed
// as tri(row, col) / other(row, col). The `other` mapper is declared ColMajor
// with inner stride OtherInnerStride.
46 typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
47 typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
48 TriMapper
tri(_tri, triStride);
49 OtherMapper other(_other, otherStride, otherIncr);
// `conj` is applied to every entry read from tri below; it is parameterised on Conjugate.
52 conj_if<Conjugate>
conj;
// i is k when IsLower and -k-1 otherwise.
// NOTE(review): the negative index presumably relies on a pointer/base shift
// made on a line that is not shown — confirm.
58 Index i = IsLower ? k : -k-1;
// Tail of a conditional expression whose head is not shown: selects i+1 when
// IsLower, i-rs otherwise.
61 : IsLower ?
i+1 :
i-rs;
// Inner-product form: l points at tri(i,s), r is a linear mapper starting at
// other(s,j); b accumulates conj(l[i3]) * r(i3) for i3 in [0, k).
69 const Scalar* l = &
tri(
i,s);
70 typename OtherMapper::LinearMapper r = other.getLinearMapper(s,
j);
71 for (
Index i3=0; i3<k; ++i3)
72 b +=
conj(l[i3]) * r(i3);
// other(i,j) <- (other(i,j) - b) * a.
74 other(
i,
j) = (other(
i,
j) -
b)*
a;
// Update form: otherij aliases other(i,j); then, for i3 in [0, rs),
// r(i3) -= b * conj(l(i3)) with r starting at other(s,j) and l at tri(s,i).
78 Scalar& otherij = other(
i,
j);
81 typename OtherMapper::LinearMapper r = other.getLinearMapper(s,
j);
82 typename TriMapper::LinearMapper l =
tri.getLinearMapper(s,
i);
83 for (
Index i3=0;i3<rs;++i3)
84 r(i3) -=
b *
conj(l(i3));
// Out-of-line definition of trsmKernelR<...>::kernel.
// NOTE(review): this listing is lossy. The leading parameters, the braces, the
// enclosing loops, the enum head and the declarations of IsLower, j, k and
// RealScalar are not shown; only the visible statements are described.
//
// Visible parameters:
//   _tri, triStride                 - read-only array wrapped by the `rhs` mapper.
//   _other, otherIncr, otherStride  - array wrapped by the `lhs` mapper.
91 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide,
bool Specialized>
92 EIGEN_STRONG_INLINE
void trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride, Specialized>::kernel(
94 const Scalar* _tri,
Index triStride,
95 Scalar* _other,
Index otherIncr,
Index otherStride)
// Naming here is by position in the product: `lhs` wraps _other (ColMajor,
// inner stride OtherInnerStride) and `rhs` wraps _tri.
98 typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
99 typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
100 LhsMapper lhs(_other, otherStride, otherIncr);
101 RhsMapper rhs(_tri, triStride);
// Enumerator fragment: RhsStorageOrder is an alias for TriStorageOrder.
104 RhsStorageOrder = TriStorageOrder,
// `conj` is applied to every entry read from rhs below; it is parameterised on Conjugate.
107 conj_if<Conjugate>
conj;
// r is a linear mapper over column j of lhs, starting at row 0.
113 typename LhsMapper::LinearMapper r = lhs.getLinearMapper(0,
j);
// For k3 in [0, k): b is the conjugated rhs entry at row (IsLower ? j+1+k3 : k3),
// column j, and a is a linear mapper over the lhs column with that same index.
// The statement combining r, a and b is not shown in this listing.
114 for (
Index k3=0; k3<k; ++k3)
116 Scalar
b =
conj(rhs(IsLower ?
j+1+k3 : k3,
j));
117 typename LhsMapper::LinearMapper
a = lhs.getLinearMapper(0,IsLower ?
j+1+k3 : k3);
// Reciprocal of the conjugated diagonal entry rhs(j,j); its use is not shown.
123 Scalar inv_rjj = RealScalar(1)/
conj(rhs(
j,
j));
// Partial specialization of triangular_solve_matrix for the case where the
// storage-order argument following TriStorageOrder is RowMajor.
// The visible part of its `run` forwards to another triangular_solve_matrix
// specialization's `run`.
// NOTE(review): the template arguments of the delegated specialization, the
// leading parameters of `run` and the braces are not shown in this listing.
132 template <
typename Scalar,
typename Index,
int S
ide,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide>
133 struct triangular_solve_matrix<Scalar,
Index,Side,Mode,Conjugate,TriStorageOrder,
RowMajor,OtherInnerStride>
// Trailing parameters of `run`: the array to update with its increment and
// stride, and the blocking object that is forwarded unchanged.
138 Scalar* _other,
Index otherIncr,
Index otherStride,
139 level3_blocking<Scalar,Scalar>& blocking)
// Delegation: all runtime arguments are passed straight through.
141 triangular_solve_matrix<
146 ::run(
size,
cols,
tri, triStride, _other, otherIncr, otherStride, blocking);
// Partial specialization of triangular_solve_matrix for Side == OnTheLeft with
// the storage-order argument following TriStorageOrder set to ColMajor.
// Only the trailing parameters of a member declaration are visible below; the
// `run` definition for this specialization follows later in the file.
// NOTE(review): the member's name, return type and leading parameters are not
// shown in this listing.
152 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide>
153 struct triangular_solve_matrix<Scalar,
Index,
OnTheLeft,Mode,Conjugate,TriStorageOrder,
ColMajor,OtherInnerStride>
// _tri/triStride: read-only operand; _other/otherIncr/otherStride: mutable
// operand; blocking: level3_blocking object passed by non-const reference.
157 const Scalar* _tri,
Index triStride,
158 Scalar* _other,
Index otherIncr,
Index otherStride,
159 level3_blocking<Scalar,Scalar>& blocking);
// Definition of run() for the OnTheLeft / ColMajor specialization.
// NOTE(review): this listing is lossy. The leading parameters (size, cols),
// the braces, the declarations of IsLower, SmallPanelWidth, mc, blockA, blockB,
// actual_kc, actual_cols, j2, i2, actual_mc and L2Cap, and the loop heads are not
// shown; the notes below describe only the visible statements.
//
// Visible parameters:
//   _tri, triStride                 - read-only operand wrapped by the `tri` mapper.
//   _other, otherIncr, otherStride  - operand wrapped by the `other` mapper.
//   blocking                        - supplies kc via blocking.kc().
162 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide>
163 EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
165 const Scalar* _tri,
Index triStride,
166 Scalar* _other,
Index otherIncr,
Index otherStride,
167 level3_blocking<Scalar,Scalar>& blocking)
// l2 feeds both the AVX512 cutoff test and the subcols computation below.
// NOTE(review): presumably these are cache sizes filled in on a line not shown — confirm.
171 std::ptrdiff_t l1, l2, l3;
// Optional fast path, compiled only when EIGEN_VECTORIZE_AVX512 is defined and
// both EIGEN_USE_AVX512_TRSM_L_KERNELS and EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS
// evaluate to true. The visible part of the condition restricts it to float or double.
174 #if defined(EIGEN_VECTORIZE_AVX512) && EIGEN_USE_AVX512_TRSM_L_KERNELS && EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS
176 (std::is_same<Scalar,float>::value ||
177 std::is_same<Scalar,double>::value)) ) {
// Below the cutoff, call the specialized kernel directly on the raw pointers.
// Both the inner-stride template argument and the otherIncr argument are the
// literal 1 here, rather than OtherInnerStride / otherIncr.
182 if (
size < avx512_trsm_cutoff<Scalar>(l2,
cols, L2Cap)) {
183 trsmKernelL<Scalar,
Index, Mode, Conjugate, TriStorageOrder, 1,
true>::kernel(
184 size,
cols, _tri, triStride, _other, 1, otherStride);
// General path: wrap the raw pointers in BLAS data mappers.
190 typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
191 typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
192 TriMapper
tri(_tri, triStride);
193 OtherMapper other(_other, otherStride, otherIncr);
195 typedef gebp_traits<Scalar,Scalar> Traits;
// kc comes from the blocking object; sizeA = kc*mc and sizeB = kc*cols.
// NOTE(review): presumably the element counts of the blockA/blockB packing buffers — confirm.
202 Index kc = blocking.kc();
205 std::size_t sizeA = kc*mc;
206 std::size_t sizeB = kc*
cols;
// Product kernel and packing functors. Conjugate is passed in the first of
// the kernel's two trailing boolean template arguments.
211 gebp_kernel<Scalar, Scalar, Index, OtherMapper, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel;
212 gemm_pack_lhs<Scalar, Index, TriMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, TriStorageOrder> pack_lhs;
213 gemm_pack_rhs<Scalar, Index, OtherMapper, Traits::nr, ColMajor, false, true> pack_rhs;
// subcols = l2 / (4 * sizeof(Scalar) * max(otherStride, size)) when cols > 0,
// else 0; it is then rounded down to a multiple of Traits::nr, with Traits::nr
// as the minimum.
217 Index subcols =
cols>0 ? l2/(4 *
sizeof(Scalar) * std::max<Index>(otherStride,
size)) : 0;
218 subcols = std::max<Index>((subcols/Traits::nr)*Traits::nr, Traits::nr);
// Condition and increment of the outer k2 loop (its head is not shown):
// k2 advances by kc while k2 < size when IsLower, and decreases by kc while
// k2 > 0 otherwise.
221 IsLower ? k2<size : k2>0;
222 IsLower ? k2+=kc : k2-=kc)
// Walk the current block in panels of at most SmallPanelWidth.
243 for (
Index k1=0; k1<actual_kc; k1+=SmallPanelWidth)
245 Index actualPanelWidth = std::min<Index>(actual_kc-k1, SmallPanelWidth);
248 Index i = IsLower ? k2+k1 : k2-k1;
// With the AVX512 TRSM_L kernels enabled (float/double per the visible part
// of the condition), i is recomputed so that the non-lower case points at the
// start of the panel, and the specialized kernel is called with both pointers
// offset: _tri by i + i*triStride, _other by i*OtherInnerStride + j2*otherStride.
249 #if defined(EIGEN_VECTORIZE_AVX512) && EIGEN_USE_AVX512_TRSM_L_KERNELS
251 (std::is_same<Scalar,float>::value ||
252 std::is_same<Scalar,double>::value)) ) {
253 i = IsLower ? k2 + k1: k2 - k1 - actualPanelWidth;
256 trsmKernelL<Scalar,
Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride,
true>::kernel(
257 actualPanelWidth, actual_cols,
258 _tri +
i + (
i)*triStride, triStride,
259 _other +
i*OtherInnerStride + j2*otherStride, otherIncr, otherStride);
// Extent and start row of the part of the block that lies beyond this panel,
// plus the offset used when packing the panel into blockB.
262 Index lengthTarget = actual_kc-k1-actualPanelWidth;
263 Index startBlock = IsLower ? k2+k1 : k2-k1-actualPanelWidth;
264 Index blockBOffset = IsLower ? k1 : lengthTarget;
// Pack the panel rows of `other` into blockB at column offset actual_kc*j2.
267 pack_rhs(blockB+actual_kc*j2, other.getSubMapper(startBlock,j2), actualPanelWidth, actual_cols, actual_kc, blockBOffset);
272 Index startTarget = IsLower ? k2+k1+actualPanelWidth : k2-actual_kc;
// Pack the matching part of tri into blockA, then run the product kernel on
// the target rows of `other` with scaling factor Scalar(-1).
274 pack_lhs(blockA,
tri.getSubMapper(startTarget,startBlock), actualPanelWidth, lengthTarget);
276 gebp_kernel(other.getSubMapper(startTarget,j2), blockA, blockB+actual_kc*j2, lengthTarget, actualPanelWidth, actual_cols, Scalar(-1),
277 actualPanelWidth, actual_kc, 0, blockBOffset);
// Start row for the update that follows the block: k2+kc when IsLower, else 0.
284 Index start = IsLower ? k2+kc : 0;
// Pack an actual_mc x actual_kc piece of tri and apply it to rows i2.. of
// `other` over all cols, again with scaling factor Scalar(-1).
291 pack_lhs(blockA,
tri.getSubMapper(i2, IsLower ? k2 : k2-kc), actual_kc, actual_mc);
293 gebp_kernel(other.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc,
cols, Scalar(-1), -1, -1, 0, 0);
// Partial specialization of triangular_solve_matrix for Side == OnTheRight with
// the storage-order argument following TriStorageOrder set to ColMajor.
// Only the trailing parameters of a member declaration are visible below; the
// `run` definition for this specialization follows later in the file.
// NOTE(review): the member's name, return type and leading parameters are not
// shown in this listing.
302 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide>
303 struct triangular_solve_matrix<Scalar,
Index,
OnTheRight,Mode,Conjugate,TriStorageOrder,
ColMajor,OtherInnerStride>
// _tri/triStride: read-only operand; _other/otherIncr/otherStride: mutable
// operand; blocking: level3_blocking object passed by non-const reference.
307 const Scalar* _tri,
Index triStride,
308 Scalar* _other,
Index otherIncr,
Index otherStride,
309 level3_blocking<Scalar,Scalar>& blocking);
// Definition of run() for the OnTheRight / ColMajor specialization.
// NOTE(review): this listing is lossy and is cut off after the last visible
// line. The leading parameters (size, rows), the braces, the declarations of
// IsLower, SmallPanelWidth, mc, blockA, blockB, actual_kc, i2, actual_mc and
// L2Cap, and several loop heads and call tails are not shown; the notes below
// describe only the visible statements.
//
// Visible parameters:
//   _tri, triStride                 - read-only operand wrapped by the `rhs` mapper.
//   _other, otherIncr, otherStride  - operand wrapped by the `lhs` mapper.
//   blocking                        - supplies kc via blocking.kc().
312 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide>
313 EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
315 const Scalar* _tri,
Index triStride,
316 Scalar* _other,
Index otherIncr,
Index otherStride,
317 level3_blocking<Scalar,Scalar>& blocking)
// Optional fast path, compiled only when EIGEN_VECTORIZE_AVX512 is defined and
// both EIGEN_USE_AVX512_TRSM_R_KERNELS and EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
// evaluate to true. The visible part of the condition restricts it to float or double.
321 #if defined(EIGEN_VECTORIZE_AVX512) && EIGEN_USE_AVX512_TRSM_R_KERNELS && EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
323 (std::is_same<Scalar,float>::value ||
324 std::is_same<Scalar,double>::value)) ) {
// l2 is the value handed to the cutoff function below.
// NOTE(review): presumably a cache size filled in on a line not shown — confirm.
326 std::ptrdiff_t l1, l2, l3;
// Below the cutoff, call the specialized kernel directly on the raw pointers,
// passing the literal 1 as the otherIncr argument.
329 if (
size < avx512_trsm_cutoff<Scalar>(l2,
rows, L2Cap)) {
330 trsmKernelR<Scalar,
Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride,
true>::
331 kernel(
size,
rows, _tri, triStride, _other, 1, otherStride);
// General path: `lhs` wraps _other and `rhs` wraps _tri.
337 typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
338 typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
339 LhsMapper lhs(_other, otherStride, otherIncr);
340 RhsMapper rhs(_tri, triStride);
342 typedef gebp_traits<Scalar,Scalar> Traits;
// Enumerator fragment: RhsStorageOrder is an alias for TriStorageOrder.
344 RhsStorageOrder = TriStorageOrder,
// kc comes from the blocking object; sizeA = kc*mc and sizeB = kc*size.
// NOTE(review): presumably the element counts of the blockA/blockB packing buffers — confirm.
349 Index kc = blocking.kc();
352 std::size_t sizeA = kc*mc;
353 std::size_t sizeB = kc*
size;
// Product kernel and packing functors. Here Conjugate is passed in the second
// of the kernel's two trailing boolean template arguments. The *_panel packers
// carry the extra trailing `true` template argument.
358 gebp_kernel<Scalar, Scalar, Index, LhsMapper, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel;
359 gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
360 gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder,false,true> pack_rhs_panel;
361 gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor, false, true> pack_lhs_panel;
// Condition and increment of the outer k2 loop (its head is not shown):
// k2 decreases by kc while k2 > 0 when IsLower, and advances by kc while
// k2 < size otherwise.
364 IsLower ? k2>0 : k2<
size;
365 IsLower ? k2-=kc : k2+=kc)
// actual_k2 is the first index of the current block; startPanel and rs give the
// start and width of the part of rhs outside the block (before it when IsLower,
// after it otherwise).
368 Index actual_k2 = IsLower ? k2-actual_kc : k2 ;
370 Index startPanel = IsLower ? 0 : k2+actual_kc;
371 Index rs = IsLower ? actual_k2 :
size - actual_k2 - actual_kc;
// geb points into blockB just past an actual_kc x actual_kc region; the
// outside part of rhs is packed there when it is non-empty.
372 Scalar* geb = blockB+actual_kc*actual_kc;
374 if (rs>0) pack_rhs(geb, rhs.getSubMapper(actual_k2,startPanel), actual_kc, rs);
// First pass over the block in panels of at most SmallPanelWidth: pack a
// panelLength x actualPanelWidth piece of rhs for each panel into blockB at
// offset j2*actual_kc.
379 for (
Index j2=0; j2<actual_kc; j2+=SmallPanelWidth)
381 Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth);
382 Index actual_j2 = actual_k2 + j2;
383 Index panelOffset = IsLower ? j2+actualPanelWidth : 0;
384 Index panelLength = IsLower ? actual_kc-j2-actualPanelWidth : j2;
387 pack_rhs_panel(blockB+j2*actual_kc,
388 rhs.getSubMapper(actual_k2+panelOffset, actual_j2),
389 panelLength, actualPanelWidth,
390 actual_kc, panelOffset);
// Second panel loop. When IsLower it starts at the last panel (actual_kc minus
// the width of the final, possibly partial, panel) and steps backwards by
// SmallPanelWidth while j2 >= 0; otherwise it steps forwards while
// j2 < actual_kc (the non-lower start value is on a line not shown).
401 for (
Index j2 = IsLower
402 ? (actual_kc - ((actual_kc%SmallPanelWidth) ?
Index(actual_kc%SmallPanelWidth)
403 :
Index(SmallPanelWidth)))
405 IsLower ? j2>=0 : j2<actual_kc;
406 IsLower ? j2-=SmallPanelWidth : j2+=SmallPanelWidth)
408 Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth);
409 Index absolute_j2 = actual_k2 + j2;
410 Index panelOffset = IsLower ? j2+actualPanelWidth : 0;
411 Index panelLength = IsLower ? actual_kc - j2 - actualPanelWidth : j2;
// Product-kernel call on the lhs columns of this panel, using the piece of
// blockB packed for it (the scaling-factor argument line is not shown).
416 gebp_kernel(lhs.getSubMapper(i2,absolute_j2),
417 blockA, blockB+j2*actual_kc,
418 actual_mc, panelLength, actualPanelWidth,
420 actual_kc, actual_kc,
421 panelOffset, panelOffset);
// Specialized kernel call for the panel, with both pointers offset:
// _tri by absolute_j2 + absolute_j2*triStride, _other by
// i2*OtherInnerStride + absolute_j2*otherStride.
426 trsmKernelR<Scalar,
Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride,
true>::
427 kernel(actualPanelWidth, actual_mc,
428 _tri + absolute_j2 + absolute_j2*triStride, triStride,
429 _other + i2*OtherInnerStride + absolute_j2*otherStride, otherIncr, otherStride);
// Pack the panel columns of lhs into blockA (remaining arguments not shown).
432 pack_lhs_panel(blockA, lhs.getSubMapper(i2,absolute_j2),
433 actualPanelWidth, actual_mc,
// Product-kernel call on the lhs columns starting at startPanel, using the
// rhs part packed at geb and scaling factor Scalar(-1) (remaining arguments
// not shown).
439 gebp_kernel(lhs.getSubMapper(i2, startPanel), blockA, geb,
440 actual_mc, actual_kc, rs, Scalar(-1),
#define EIGEN_DONT_INLINE
#define EIGEN_IF_CONSTEXPR(X)
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
Tridiagonalization< MatrixXf > tri
static const lastp1_t end
bfloat16() min(const bfloat16 &a, const bfloat16 &b)
constexpr int plain_enum_max(A a, B b)
void manage_caching_sizes(Action action, std::ptrdiff_t *l1, std::ptrdiff_t *l2, std::ptrdiff_t *l3)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_conjugate_op< typename Derived::Scalar >, const Derived > conj(const Eigen::ArrayBase< Derived > &x)