19 template <
typename Scalar_>
22 typedef Scalar_ Scalar;
23 typedef std::complex<Scalar> Complex;
24 std::vector<Complex> m_twiddles;
25 std::vector<int> m_stageRadix;
26 std::vector<int> m_stageRemainder;
27 std::vector<Complex> m_scratchBuf;
30 inline void make_twiddles(
int nfft,
bool inverse)
35 m_twiddles.resize(nfft);
36 double phinc = 0.25 * double(EIGEN_PI) / nfft;
37 Scalar flip =
inverse ? Scalar(1) : Scalar(-1);
38 m_twiddles[0] = Complex(Scalar(1), Scalar(0));
40 m_twiddles[nfft/2] = Complex(Scalar(-1), Scalar(0));
44 Scalar
c = Scalar(
cos(i*8*phinc));
45 Scalar s = Scalar(
sin(i*8*phinc));
46 m_twiddles[
i] = Complex(c, s*flip);
47 m_twiddles[nfft-
i] = Complex(c, -s*flip);
51 Scalar
c = Scalar(
cos((2*nfft-8*i)*phinc));
52 Scalar s = Scalar(
sin((2*nfft-8*i)*phinc));
53 m_twiddles[
i] = Complex(s, c*flip);
54 m_twiddles[nfft-
i] = Complex(s, -c*flip);
58 Scalar
c = Scalar(
cos((8*i-2*nfft)*phinc));
59 Scalar s = Scalar(
sin((8*i-2*nfft)*phinc));
60 m_twiddles[
i] = Complex(-s, c*flip);
61 m_twiddles[nfft-
i] = Complex(-s, -c*flip);
65 Scalar
c = Scalar(
cos((4*nfft-8*i)*phinc));
66 Scalar s = Scalar(
sin((4*nfft-8*i)*phinc));
67 m_twiddles[
i] = Complex(-c, s*flip);
68 m_twiddles[nfft-
i] = Complex(-c, -s*flip);
72 void factorize(
int nfft)
82 default:
p += 2;
break;
88 m_stageRadix.push_back(p);
89 m_stageRemainder.push_back(n);
91 m_scratchBuf.resize(p);
95 template <
typename Src_>
97 void work(
int stage,Complex * xout,
const Src_ * xin,
size_t fstride,
size_t in_stride)
99 int p = m_stageRadix[stage];
100 int m = m_stageRemainder[stage];
101 Complex * Fout_beg = xout;
102 Complex * Fout_end = xout +
p*
m;
110 work(stage+1, xout , xin, fstride*p,in_stride);
111 xin += fstride*in_stride;
112 }
while( (xout += m) != Fout_end );
116 xin += fstride*in_stride;
117 }
while(++xout != Fout_end );
123 case 2: bfly2(xout,fstride,m);
break;
124 case 3: bfly3(xout,fstride,m);
break;
125 case 4: bfly4(xout,fstride,m);
break;
126 case 5: bfly5(xout,fstride,m);
break;
127 default: bfly_generic(xout,fstride,m,p);
break;
132 void bfly2( Complex * Fout,
const size_t fstride,
int m)
134 for (
int k=0;k<
m;++k) {
135 Complex t = Fout[
m+k] * m_twiddles[k*fstride];
136 Fout[
m+k] = Fout[k] - t;
142 void bfly4( Complex * Fout,
const size_t fstride,
const size_t m)
145 int negative_if_inverse = m_inverse * -2 +1;
146 for (
size_t k=0;k<
m;++k) {
147 scratch[0] = Fout[k+
m] * m_twiddles[k*fstride];
148 scratch[1] = Fout[k+2*
m] * m_twiddles[k*fstride*2];
149 scratch[2] = Fout[k+3*
m] * m_twiddles[k*fstride*3];
150 scratch[5] = Fout[k] - scratch[1];
152 Fout[k] += scratch[1];
153 scratch[3] = scratch[0] + scratch[2];
154 scratch[4] = scratch[0] - scratch[2];
155 scratch[4] = Complex( scratch[4].
imag()*negative_if_inverse , -scratch[4].
real()* negative_if_inverse );
157 Fout[k+2*
m] = Fout[k] - scratch[3];
158 Fout[k] += scratch[3];
159 Fout[k+
m] = scratch[5] + scratch[4];
160 Fout[k+3*
m] = scratch[5] - scratch[4];
165 void bfly3( Complex * Fout,
const size_t fstride,
const size_t m)
168 const size_t m2 = 2*
m;
172 epi3 = m_twiddles[fstride*
m];
174 tw1=tw2=&m_twiddles[0];
177 scratch[1]=Fout[
m] * *tw1;
178 scratch[2]=Fout[
m2] * *tw2;
180 scratch[3]=scratch[1]+scratch[2];
181 scratch[0]=scratch[1]-scratch[2];
184 Fout[
m] = Complex( Fout->real() - Scalar(.5)*scratch[3].
real() , Fout->imag() - Scalar(.5)*scratch[3].
imag() );
185 scratch[0] *= epi3.imag();
187 Fout[
m2] = Complex( Fout[m].
real() + scratch[0].
imag() , Fout[m].
imag() - scratch[0].
real() );
188 Fout[
m] += Complex( -scratch[0].
imag(),scratch[0].
real() );
194 void bfly5( Complex * Fout,
const size_t fstride,
const size_t m)
196 Complex *Fout0,*Fout1,*Fout2,*Fout3,*Fout4;
199 Complex * twiddles = &m_twiddles[0];
202 ya = twiddles[fstride*
m];
203 yb = twiddles[fstride*2*
m];
212 for ( u=0; u<
m; ++u ) {
215 scratch[1] = *Fout1 * tw[u*fstride];
216 scratch[2] = *Fout2 * tw[2*u*fstride];
217 scratch[3] = *Fout3 * tw[3*u*fstride];
218 scratch[4] = *Fout4 * tw[4*u*fstride];
220 scratch[7] = scratch[1] + scratch[4];
221 scratch[10] = scratch[1] - scratch[4];
222 scratch[8] = scratch[2] + scratch[3];
223 scratch[9] = scratch[2] - scratch[3];
225 *Fout0 += scratch[7];
226 *Fout0 += scratch[8];
228 scratch[5] = scratch[0] + Complex(
229 (scratch[7].
real()*ya.real() ) + (scratch[8].real() *yb.real() ),
230 (scratch[7].imag()*ya.real()) + (scratch[8].imag()*yb.real())
233 scratch[6] = Complex(
234 (scratch[10].
imag()*ya.imag()) + (scratch[9].imag()*yb.imag()),
235 -(scratch[10].real()*ya.imag()) - (scratch[9].real()*yb.imag())
238 *Fout1 = scratch[5] - scratch[6];
239 *Fout4 = scratch[5] + scratch[6];
241 scratch[11] = scratch[0] +
243 (scratch[7].
real()*yb.real()) + (scratch[8].real()*ya.real()),
244 (scratch[7].imag()*yb.real()) + (scratch[8].imag()*ya.real())
247 scratch[12] = Complex(
248 -(scratch[10].
imag()*yb.imag()) + (scratch[9].imag()*ya.imag()),
249 (scratch[10].real()*yb.imag()) - (scratch[9].real()*ya.imag())
252 *Fout2=scratch[11]+scratch[12];
253 *Fout3=scratch[11]-scratch[12];
255 ++Fout0;++Fout1;++Fout2;++Fout3;++Fout4;
263 const size_t fstride,
269 Complex * twiddles = &m_twiddles[0];
271 int Norig =
static_cast<int>(m_twiddles.size());
272 Complex * scratchbuf = &m_scratchBuf[0];
274 for ( u=0; u<
m; ++u ) {
276 for ( q1=0 ; q1<
p ; ++q1 ) {
277 scratchbuf[q1] = Fout[ k ];
282 for ( q1=0 ; q1<
p ; ++q1 ) {
284 Fout[ k ] = scratchbuf[0];
286 twidx +=
static_cast<int>(fstride) * k;
287 if (twidx>=Norig) twidx-=Norig;
288 t=scratchbuf[
q] * twiddles[twidx];
297 template <
typename Scalar_>
300 typedef Scalar_ Scalar;
301 typedef std::complex<Scalar> Complex;
306 m_realTwiddles.clear();
310 void fwd( Complex * dst,
const Complex *src,
int nfft)
312 get_plan(nfft,
false).work(0, dst, src, 1,1);
316 void fwd2( Complex * dst,
const Complex *src,
int n0,
int n1)
325 void inv2( Complex * dst,
const Complex *src,
int n0,
int n1)
338 void fwd( Complex * dst,
const Scalar * src,
int nfft)
342 m_tmpBuf1.resize(nfft);
343 get_plan(nfft,
false).work(0, &m_tmpBuf1[0], src, 1,1);
344 std::copy(m_tmpBuf1.begin(),m_tmpBuf1.begin()+(nfft>>1)+1,dst );
347 int ncfft2 = nfft>>2;
348 Complex * rtw = real_twiddles(ncfft2);
351 fwd( dst,
reinterpret_cast<const Complex*
> (src), ncfft);
352 Complex dc(dst[0].
real() + dst[0].
imag());
353 Complex nyquist(dst[0].
real() - dst[0].
imag());
355 for ( k=1;k <= ncfft2 ; ++k ) {
356 Complex fpk = dst[k];
357 Complex fpnk =
conj(dst[ncfft-k]);
358 Complex f1k = fpk + fpnk;
359 Complex f2k = fpk - fpnk;
360 Complex tw= f2k * rtw[k-1];
361 dst[k] = (f1k + tw) * Scalar(.5);
362 dst[ncfft-k] =
conj(f1k -tw)*Scalar(.5);
365 dst[ncfft] = nyquist;
371 void inv(Complex * dst,
const Complex *src,
int nfft)
373 get_plan(nfft,
true).work(0, dst, src, 1,1);
378 void inv( Scalar * dst,
const Complex * src,
int nfft)
381 m_tmpBuf1.resize(nfft);
382 m_tmpBuf2.resize(nfft);
383 std::copy(src,src+(nfft>>1)+1,m_tmpBuf1.begin() );
384 for (
int k=1;k<(nfft>>1)+1;++k)
385 m_tmpBuf1[nfft-k] =
conj(m_tmpBuf1[k]);
386 inv(&m_tmpBuf2[0],&m_tmpBuf1[0],nfft);
387 for (
int k=0;k<nfft;++k)
388 dst[k] = m_tmpBuf2[k].
real();
392 int ncfft2 = nfft>>2;
393 Complex * rtw = real_twiddles(ncfft2);
394 m_tmpBuf1.resize(ncfft);
395 m_tmpBuf1[0] = Complex( src[0].
real() + src[ncfft].
real(), src[0].
real() - src[ncfft].
real() );
396 for (
int k = 1; k <= ncfft / 2; ++k) {
398 Complex fnkc =
conj(src[ncfft-k]);
399 Complex fek = fk + fnkc;
400 Complex tmp = fk - fnkc;
401 Complex fok = tmp *
conj(rtw[k-1]);
402 m_tmpBuf1[k] = fek + fok;
403 m_tmpBuf1[ncfft-k] =
conj(fek - fok);
405 get_plan(ncfft,
true).work(0,
reinterpret_cast<Complex*
>(dst), &m_tmpBuf1[0], 1,1);
410 typedef kiss_cpx_fft<Scalar> PlanData;
411 typedef std::map<int,PlanData> PlanMap;
414 std::map<int, std::vector<Complex> > m_realTwiddles;
415 std::vector<Complex> m_tmpBuf1;
416 std::vector<Complex> m_tmpBuf2;
419 int PlanKey(
int nfft,
bool isinverse)
const {
return (nfft<<1) | int(isinverse); }
422 PlanData & get_plan(
int nfft,
bool inverse)
425 PlanData & pd = m_plans[ PlanKey(nfft,
inverse) ];
426 if ( pd.m_twiddles.size() == 0 ) {
427 pd.make_twiddles(nfft,
inverse);
434 Complex * real_twiddles(
int ncfft2)
437 std::vector<Complex> & twidref = m_realTwiddles[ncfft2];
438 if ( (
int)twidref.size() != ncfft2 ) {
439 twidref.resize(ncfft2);
440 int ncfft= ncfft2<<1;
441 Scalar pi =
acos( Scalar(-1) );
442 for (
int k=1;k<=ncfft2;++k)
443 twidref[k-1] =
exp( Complex(0,-pi * (Scalar(k) / ncfft + Scalar(.5)) ) );
#define EIGEN_UNUSED_VARIABLE(var)
EIGEN_ALWAYS_INLINE T sin(const T &x)
EIGEN_ALWAYS_INLINE T cos(const T &x)
: TensorContractionSycl.h, provides various tensor contraction kernel for SYCL backend
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_imag_op< typename Derived::Scalar >, const Derived > imag(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_acos_op< typename Derived::Scalar >, const Derived > acos(const Eigen::ArrayBase< Derived > &x)
Eigen::AutoDiffScalar< EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Eigen::internal::remove_all_t< DerType >, typename Eigen::internal::traits< Eigen::internal::remove_all_t< DerType >>::Scalar, product) > acos(const Eigen::AutoDiffScalar< DerType > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_conjugate_op< typename Derived::Scalar >, const Derived > conj(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_real_op< typename Derived::Scalar >, const Derived > real(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_cos_op< typename Derived::Scalar >, const Derived > cos(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_exp_op< typename Derived::Scalar >, const Derived > exp(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_sin_op< typename Derived::Scalar >, const Derived > sin(const Eigen::ArrayBase< Derived > &x)
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_inverse_op< typename Derived::Scalar >, const Derived > inverse(const Eigen::ArrayBase< Derived > &x)