19 #define RUN_OR_ASSERT(EXPR, ERROR_MSG) \
21 MKL_LONG status = (EXPR); \
22 eigen_assert(status == DFTI_NO_ERROR && (ERROR_MSG)); \
25 inline MKL_Complex16* complex_cast(
const std::complex<double>* p) {
26 return const_cast<MKL_Complex16*
>(
reinterpret_cast<const MKL_Complex16*
>(
p));
29 inline MKL_Complex8* complex_cast(
const std::complex<float>* p) {
30 return const_cast<MKL_Complex8*
>(
reinterpret_cast<const MKL_Complex8*
>(
p));
42 inline void configure_descriptor(std::shared_ptr<DFTI_DESCRIPTOR>& handl,
43 enum DFTI_CONFIG_VALUE precision,
44 enum DFTI_CONFIG_VALUE forward_domain,
45 MKL_LONG dimension, MKL_LONG* sizes) {
48 "Transformation dimension must be less than 3.");
50 DFTI_DESCRIPTOR_HANDLE
res =
nullptr;
52 RUN_OR_ASSERT(DftiCreateDescriptor(&res, precision, forward_domain,
54 "DftiCreateDescriptor failed.")
55 handl.reset(res, [](DFTI_DESCRIPTOR_HANDLE handle) { DftiFreeDescriptor(&handle); });
56 if (forward_domain == DFTI_REAL) {
58 RUN_OR_ASSERT(DftiSetValue(handl.get(), DFTI_CONJUGATE_EVEN_STORAGE,
59 DFTI_COMPLEX_COMPLEX),
60 "DftiSetValue failed.")
64 DftiCreateDescriptor(&res, precision, DFTI_COMPLEX, dimension, sizes),
65 "DftiCreateDescriptor failed.")
66 handl.reset(res, [](DFTI_DESCRIPTOR_HANDLE handle) { DftiFreeDescriptor(&handle); });
69 RUN_OR_ASSERT(DftiSetValue(handl.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE),
70 "DftiSetValue failed.")
71 RUN_OR_ASSERT(DftiCommitDescriptor(handl.get()), "DftiCommitDescriptor failed.")
79 typedef float scalar_type;
80 typedef MKL_Complex8 complex_type;
82 std::shared_ptr<DFTI_DESCRIPTOR> m_plan;
86 enum DFTI_CONFIG_VALUE precision = DFTI_SINGLE;
88 inline void forward(complex_type* dst, complex_type* src, MKL_LONG nfft) {
90 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
93 "DftiComputeForward failed.")
96 inline void inverse(complex_type* dst, complex_type* src, MKL_LONG nfft) {
98 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
101 "DftiComputeBackward failed.")
104 inline void forward(complex_type* dst, scalar_type* src, MKL_LONG nfft) {
106 configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
109 "DftiComputeForward failed.")
112 inline void inverse(scalar_type* dst, complex_type* src, MKL_LONG nfft) {
114 configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
117 "DftiComputeBackward failed.")
120 inline void forward2(complex_type* dst, complex_type* src,
int n0,
int n1) {
122 MKL_LONG sizes[2] = {n0, n1};
123 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
126 "DftiComputeForward failed.")
129 inline void inverse2(complex_type* dst, complex_type* src,
int n0,
int n1) {
131 MKL_LONG sizes[2] = {n0, n1};
132 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
135 "DftiComputeBackward failed.")
140 struct plan<double> {
141 typedef double scalar_type;
142 typedef MKL_Complex16 complex_type;
144 std::shared_ptr<DFTI_DESCRIPTOR> m_plan;
148 enum DFTI_CONFIG_VALUE precision = DFTI_DOUBLE;
150 inline void forward(complex_type* dst, complex_type* src, MKL_LONG nfft) {
152 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
155 "DftiComputeForward failed.")
158 inline void inverse(complex_type* dst, complex_type* src, MKL_LONG nfft) {
160 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
163 "DftiComputeBackward failed.")
166 inline void forward(complex_type* dst, scalar_type* src, MKL_LONG nfft) {
168 configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
171 "DftiComputeForward failed.")
174 inline void inverse(scalar_type* dst, complex_type* src, MKL_LONG nfft) {
176 configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
179 "DftiComputeBackward failed.")
182 inline void forward2(complex_type* dst, complex_type* src,
int n0,
int n1) {
184 MKL_LONG sizes[2] = {n0, n1};
185 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
188 "DftiComputeForward failed.")
191 inline void inverse2(complex_type* dst, complex_type* src,
int n0,
int n1) {
193 MKL_LONG sizes[2] = {n0, n1};
194 configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
197 "DftiComputeBackward failed.")
201 template <
typename Scalar_>
202 struct imklfft_impl {
203 typedef Scalar_ Scalar;
204 typedef std::complex<Scalar> Complex;
206 inline void clear() { m_plans.clear(); }
209 inline void fwd(Complex* dst,
const Complex* src,
int nfft) {
210 MKL_LONG
size = nfft;
211 get_plan(nfft, dst, src)
212 .forward(complex_cast(dst), complex_cast(src), size);
216 inline void fwd(Complex* dst,
const Scalar* src,
int nfft) {
217 MKL_LONG
size = nfft;
218 get_plan(nfft, dst, src)
219 .forward(complex_cast(dst),
const_cast<Scalar*
>(src), nfft);
223 inline void fwd2(Complex* dst,
const Complex* src,
int n0,
int n1) {
224 get_plan(n0, n1, dst, src)
225 .forward2(complex_cast(dst), complex_cast(src), n0, n1);
229 inline void inv(Complex* dst,
const Complex* src,
int nfft) {
230 MKL_LONG
size = nfft;
231 get_plan(nfft, dst, src)
232 .inverse(complex_cast(dst), complex_cast(src), nfft);
236 inline void inv(Scalar* dst,
const Complex* src,
int nfft) {
237 MKL_LONG
size = nfft;
238 get_plan(nfft, dst, src)
239 .inverse(
const_cast<Scalar*
>(dst), complex_cast(src), nfft);
243 inline void inv2(Complex* dst,
const Complex* src,
int n0,
int n1) {
244 get_plan(n0, n1, dst, src)
245 .inverse2(complex_cast(dst), complex_cast(src), n0, n1);
249 std::map<int64_t, plan<Scalar>> m_plans;
251 inline plan<Scalar>& get_plan(
int nfft,
void* dst,
253 int inplace = dst == src ? 1 : 0;
254 int aligned = ((
reinterpret_cast<size_t>(src) & 15) |
255 (
reinterpret_cast<size_t>(dst) & 15)) == 0
258 int64_t key = ((nfft << 2) | (inplace << 1) | aligned)
265 inline plan<Scalar>& get_plan(
int n0,
int n1,
void* dst,
267 int inplace = (dst == src) ? 1 : 0;
268 int aligned = ((
reinterpret_cast<size_t>(src) & 15) |
269 (
reinterpret_cast<size_t>(dst) & 15)) == 0
273 (inplace << 1) | aligned)
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
#define RUN_OR_ASSERT(EXPR, ERROR_MSG)
: TensorContractionSycl.h, provides various tensor contraction kernel for SYCL backend
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_inverse_op< typename Derived::Scalar >, const Derived > inverse(const Eigen::ArrayBase< Derived > &x)