10 #ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
11 #define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
13 #include "../InternalHeaderCheck.h"
17 template<
typename Scalar,
typename Index,
int StorageOrder,
int UpLo,
bool ConjLhs,
bool ConjRhs>
30 template<
typename LhsScalar,
typename RhsScalar,
typename Index,
int mr,
int nr,
bool ConjLhs,
bool ConjRhs,
int ResInnerStr
ide,
int UpLo>
34 template <
typename Index,
35 typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
36 typename RhsScalar,
int RhsStorageOrder,
bool ConjugateRhs,
37 int ResStorageOrder,
int ResInnerStride,
int UpLo,
int Version =
Specialized>
38 struct general_matrix_matrix_triangular_product;
41 template <
typename Index,
typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
42 typename RhsScalar,
int RhsStorageOrder,
bool ConjugateRhs,
43 int ResInnerStride,
int UpLo,
int Version>
44 struct general_matrix_matrix_triangular_product<
Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,
RowMajor,ResInnerStride,UpLo,Version>
47 static EIGEN_STRONG_INLINE
void run(
Index size,
Index depth,
const LhsScalar* lhs,
Index lhsStride,
49 const ResScalar& alpha, level3_blocking<RhsScalar,LhsScalar>& blocking)
51 general_matrix_matrix_triangular_product<
Index,
55 ::run(
size,depth,rhs,rhsStride,lhs,lhsStride,
res,resIncr,resStride,alpha,blocking);
59 template <
typename Index,
typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
60 typename RhsScalar,
int RhsStorageOrder,
bool ConjugateRhs,
61 int ResInnerStride,
int UpLo,
int Version>
62 struct general_matrix_matrix_triangular_product<
Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,
ColMajor,ResInnerStride,UpLo,Version>
64 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
65 static EIGEN_STRONG_INLINE
void run(
Index size,
Index depth,
const LhsScalar* _lhs,
Index lhsStride,
66 const RhsScalar* _rhs,
Index rhsStride,
67 ResScalar* _res,
Index resIncr,
Index resStride,
68 const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking)
70 typedef gebp_traits<LhsScalar,RhsScalar> Traits;
72 typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
73 typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
74 typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
75 LhsMapper lhs(_lhs,lhsStride);
76 RhsMapper rhs(_rhs,rhsStride);
77 ResMapper
res(_res, resStride, resIncr);
79 Index kc = blocking.kc();
84 mc = (mc/Traits::nr)*Traits::nr;
86 std::size_t sizeA = kc*mc;
87 std::size_t sizeB = kc*
size;
92 gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs;
93 gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
94 gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
95 tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, ResInnerStride, UpLo> sybb;
97 for(
Index k2=0; k2<depth; k2+=kc)
102 pack_rhs(blockB, rhs.getSubMapper(k2,0), actual_kc,
size);
108 pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
115 gebp(
res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc,
118 sybb(_res+resStride*i2 + resIncr*i2, resIncr, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha);
122 Index j2 = i2+actual_mc;
123 gebp(
res.getSubMapper(i2, j2), blockA, blockB+actual_kc*j2, actual_mc,
140 template<
typename LhsScalar,
typename RhsScalar,
typename Index,
int mr,
int nr,
bool ConjLhs,
bool ConjRhs,
int ResInnerStr
ide,
int UpLo>
143 typedef gebp_traits<LhsScalar,RhsScalar,ConjLhs,ConjRhs> Traits;
144 typedef typename Traits::ResScalar ResScalar;
151 typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
152 typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned> BufferMapper;
153 ResMapper
res(_res, resStride, resIncr);
154 gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel1;
155 gebp_kernel<LhsScalar, RhsScalar, Index, BufferMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel2;
157 Matrix<ResScalar,BlockSize,BlockSize,ColMajor> buffer((internal::constructor_without_unaligned_array_assert()));
163 Index actualBlockSize = std::min<Index>(BlockSize,
size -
j);
164 const RhsScalar* actual_b = blockB+
j*depth;
167 gebp_kernel1(
res.getSubMapper(0,
j), blockA, actual_b,
j, depth, actualBlockSize, alpha,
175 gebp_kernel2(BufferMapper(buffer.data(), BlockSize), blockA+depth*
i, actual_b, actualBlockSize, depth, actualBlockSize, alpha,
179 for(
Index j1=0; j1<actualBlockSize; ++j1)
181 typename ResMapper::LinearMapper r =
res.getLinearMapper(
i,
j+j1);
183 UpLo==
Lower ? i1<actualBlockSize : i1<=j1; ++i1)
184 r(i1) += buffer(i1,j1);
191 gebp_kernel1(
res.getSubMapper(
i,
j), blockA+depth*
i, actual_b,
size-
i,
192 depth, actualBlockSize, alpha, -1, -1, 0, 0);
202 template<
typename MatrixType,
typename ProductType,
int UpLo,
bool IsOuterProduct>
206 template<
typename MatrixType,
typename ProductType,
int UpLo>
214 typedef internal::blas_traits<Lhs> LhsBlasTraits;
215 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
220 typedef internal::blas_traits<Rhs> RhsBlasTraits;
221 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
225 Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
228 mat.template triangularView<UpLo>().setZero();
232 UseLhsDirectly = ActualLhs_::InnerStrideAtCompileTime==1,
233 UseRhsDirectly = ActualRhs_::InnerStrideAtCompileTime==1
236 internal::gemv_static_vector_if<Scalar,Lhs::SizeAtCompileTime,Lhs::MaxSizeAtCompileTime,!UseLhsDirectly> static_lhs;
238 (UseLhsDirectly ?
const_cast<Scalar*
>(actualLhs.data()) : static_lhs.data()));
241 internal::gemv_static_vector_if<Scalar,Rhs::SizeAtCompileTime,Rhs::MaxSizeAtCompileTime,!UseRhsDirectly> static_rhs;
243 (UseRhsDirectly ?
const_cast<Scalar*
>(actualRhs.data()) : static_rhs.data()));
250 ::run(actualLhs.size(),
mat.data(),
mat.outerStride(), actualLhsPtr, actualRhsPtr, actualAlpha);
254 template<
typename MatrixType,
typename ProductType,
int UpLo>
260 typedef internal::blas_traits<Lhs> LhsBlasTraits;
261 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
266 typedef internal::blas_traits<Rhs> RhsBlasTraits;
267 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
271 typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
274 mat.template triangularView<UpLo>().setZero();
277 IsRowMajor = (internal::traits<MatrixType>::Flags&
RowMajorBit) ? 1 : 0,
278 LhsIsRowMajor = ActualLhs_::Flags&
RowMajorBit ? 1 : 0,
279 RhsIsRowMajor = ActualRhs_::Flags&
RowMajorBit ? 1 : 0,
286 Index depth = actualLhs.cols();
288 typedef internal::gemm_blocking_space<IsRowMajor ?
RowMajor :
ColMajor,
typename Lhs::Scalar,
typename Rhs::Scalar,
289 MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime, ActualRhs_::MaxColsAtCompileTime> BlockingType;
291 BlockingType blocking(
size,
size, depth, 1,
false);
293 internal::general_matrix_matrix_triangular_product<
Index,
294 typename Lhs::Scalar, LhsIsRowMajor ?
RowMajor :
ColMajor, LhsBlasTraits::NeedToConjugate,
295 typename Rhs::Scalar, RhsIsRowMajor ?
RowMajor :
ColMajor, RhsBlasTraits::NeedToConjugate,
298 &actualLhs.coeffRef(SkipDiag&&(UpLo&
Lower)==
Lower ? 1 : 0,0), actualLhs.outerStride(),
299 &actualRhs.coeffRef(0,SkipDiag&&(UpLo&
Upper)==
Upper ? 1 : 0), actualRhs.outerStride(),
300 mat.data() + (SkipDiag ? (
bool(IsRowMajor) != ((UpLo&
Lower)==
Lower) ?
mat.innerStride() :
mat.outerStride() ) : 0),
301 mat.innerStride(),
mat.outerStride(), actualAlpha, blocking);
305 template<
typename MatrixType,
unsigned int UpLo>
306 template<
typename ProductType>
307 EIGEN_DEVICE_FUNC TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(
const ProductType& prod,
const Scalar& alpha,
bool beta)
310 eigen_assert(derived().nestedExpression().
rows() == prod.rows() && derived().
cols() == prod.cols());
312 general_product_to_triangular_selector<MatrixType, ProductType, UpLo, internal::traits<ProductType>::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha, beta);
IndexedView_or_Block operator()(const RowIndices &rowIndices, const ColIndices &colIndices)
#define EIGEN_DEVICE_FUNC
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
#define EIGEN_STATIC_ASSERT(X, MSG)
internal::traits< Derived >::Scalar Scalar
A matrix or vector expression mapping an existing array of data.
The matrix class, also used for vectors and row-vectors.
const unsigned int RowMajorBit
bfloat16() max(const bfloat16 &a, const bfloat16 &b)
bfloat16() min(const bfloat16 &a, const bfloat16 &b)
constexpr int plain_enum_min(A a, B b)
constexpr int plain_enum_max(A a, B b)
typename remove_all< T >::type remove_all_t
typename add_const_on_value_type< T >::type add_const_on_value_type_t
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Holds information about the various numeric (i.e. scalar) types allowed by Eigen.
Determines whether the given binary operation of two numeric types is allowed and what the scalar ret...
static void run(MatrixType &mat, const ProductType &prod, const typename MatrixType::Scalar &alpha, bool beta)
static void run(MatrixType &mat, const ProductType &prod, const typename MatrixType::Scalar &alpha, bool beta)