#ifndef EIGEN_CXX11_TENSOR_TENSOR_FFT_H
#define EIGEN_CXX11_TENSOR_TENSOR_FFT_H
// ... (includes, documentation, the opening of namespace Eigen and the
// primary MakeComplex template are elided from this excerpt)

// Promotes a real scalar to a complex value with a zero imaginary part.
template <> struct MakeComplex<true> {
  template <typename T>
  std::complex<T> operator() (const T& val) const { return std::complex<T>(val, 0); }
};

// Passes an already-complex value through unchanged.
template <> struct MakeComplex<false> {
  template <typename T>
  std::complex<T> operator() (const std::complex<T>& val) const { return val; }
};
template <int ResultType> struct PartOf {
  template <typename T> T operator() (const T& val) const { return val; }
};
template <> struct PartOf<RealPart> {
  template <typename T> T operator() (const std::complex<T>& val) const { return val.real(); }
};
template <> struct PartOf<ImagPart> {
  template <typename T> T operator() (const std::complex<T>& val) const { return val.imag(); }
};
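// Together these helpers let the evaluator below work on a single complex
// buffer regardless of the input and output types: MakeComplex widens real
// input, PartOf narrows the result when only one part was requested.
// An illustrative check (a sketch, not part of this header):
//
//   std::complex<float> c(3.0f, 4.0f);
//   float re = PartOf<RealPart>()(c);  // 3.0f
//   float im = PartOf<ImagPart>()(c);  // 4.0f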
namespace internal {

template <typename FFT, typename XprType, int FFTResultType, int FFTDir>
struct traits<TensorFFTOp<FFT, XprType, FFTResultType, FFTDir> > : public traits<XprType> {
  typedef traits<XprType> XprTraits;
  typedef typename NumTraits<typename XprTraits::Scalar>::Real RealScalar;
  typedef typename std::complex<RealScalar> ComplexScalar;
  typedef typename XprTraits::Scalar InputScalar;
  typedef std::conditional_t<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar> OutputScalar;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef std::remove_reference_t<Nested> Nested_;
  static constexpr int NumDimensions = XprTraits::NumDimensions;
  static constexpr int Layout = XprTraits::Layout;
};
template <typename FFT, typename XprType, int FFTResultType, int FFTDirection>
struct eval<TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection>, Eigen::Dense> {
  typedef const TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection>& type;
};
template <typename FFT, typename XprType, int FFTResultType, int FFTDirection>
struct nested<TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection>, 1, typename eval<TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection> >::type> {
  typedef TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection> type;
};

}  // end namespace internal
template <typename FFT, typename XprType, int FFTResultType, int FFTDir>
class TensorFFTOp : public TensorBase<TensorFFTOp<FFT, XprType, FFTResultType, FFTDir>, ReadOnlyAccessors> {
 public:
  typedef typename Eigen::internal::traits<TensorFFTOp>::Scalar Scalar;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  typedef typename std::complex<RealScalar> ComplexScalar;
  typedef std::conditional_t<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar> OutputScalar;
  typedef OutputScalar CoeffReturnType;
  typedef typename Eigen::internal::nested<TensorFFTOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorFFTOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorFFTOp>::Index Index;

  // ... (constructor and the fft() / expression() accessors elided)
};
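// Illustrative use of the public API (a sketch, not part of this header):
// a TensorFFTOp is normally created via TensorBase::fft(). For a 2-D float
// tensor, a forward complex FFT along dimension 0 could look like:
//
//   Eigen::Tensor<float, 2> input(4, 8);
//   input.setRandom();
//   Eigen::array<int, 1> fft_dims = {{0}};
//   Eigen::Tensor<std::complex<float>, 2> output =
//       input.fft<Eigen::BothParts, Eigen::FFT_FORWARD>(fft_dims);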
// Eval as rvalue
template <typename FFT, typename ArgType, typename Device, int FFTResultType, int FFTDir>
struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, Device> {
  typedef TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir> XprType;
  typedef typename XprType::Index Index;
  static constexpr int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  // ... (Scalar, RealScalar, ComplexScalar, InputScalar and related
  // typedefs elided)
  typedef std::conditional_t<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar> OutputScalar;
  typedef OutputScalar CoeffReturnType;
  typedef typename PacketType<OutputScalar, Device>::type PacketReturnType;
  static constexpr int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    // ... (IsAligned, PacketAccess and BlockAccess flags elided)
    PreferBlockAccess = false,
    // ... (CoordAccess and RawAccess flags elided)
  };
  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_fft(op.fft()), m_impl(op.expression(), device), m_data(NULL), m_device(device) {
    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
    for (int i = 0; i < NumDims; ++i) {
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i];
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // Column-major: the first dimension is innermost.
      m_strides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_strides[i] = m_strides[i - 1] * m_dimensions[i - 1];
      }
    } else {
      // Row-major: the last dimension is innermost.
      m_strides[NumDims - 1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        m_strides[i] = m_strides[i + 1] * m_dimensions[i + 1];
      }
    }
    m_size = m_dimensions.TotalSize();
  }
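  // For instance, a ColMajor tensor of dimensions (2, 3, 4) gets
  // m_strides = {1, 2, 6} and m_size = 24; the same shape stored RowMajor
  // gets m_strides = {12, 4, 1}.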
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
    m_impl.evalSubExprsIfNeeded(NULL);
    // ... (writes into `data` if a destination was supplied and returns
    // false; otherwise allocates an internal buffer, evaluates into it via
    // evalToBuf, and returns true)
  }
  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const {
    return internal::ploadt<PacketReturnType, LoadMode>(m_data + index);
  }
  // Materializes the FFT of the whole input into the buffer `data`.
  EIGEN_DEVICE_FUNC void evalToBuf(EvaluatorPointerType data) {
    const bool write_to_out = internal::is_same<OutputScalar, ComplexScalar>::value;
    ComplexScalar* buf =
        write_to_out ? (ComplexScalar*)data
                     : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * m_size);

    // Promote the input to complex values.
    for (Index i = 0; i < m_size; ++i) {
      buf[i] = MakeComplex<internal::is_same<InputScalar, RealScalar>::value>()(m_impl.coeff(i));
    }

    // Transform one requested dimension at a time.
    for (size_t i = 0; i < m_fft.size(); ++i) {
      Index dim = m_fft[i];
      eigen_assert(dim >= 0 && dim < NumDims);
      Index line_len = m_dimensions[dim];
      ComplexScalar* line_buf = (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * line_len);
      const bool is_power_of_two = isPowerOfTwo(line_len);
      const Index good_composite = is_power_of_two ? 0 : findGoodComposite(line_len);
      const Index log_len = is_power_of_two ? getLog2(line_len) : getLog2(good_composite);
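      // Example: line_len = 100 is not a power of two, so
      // findGoodComposite(100) = 256 (the smallest power of two
      // >= 2 * 100 - 1 = 199) and log_len = getLog2(256) = 8; the
      // Bluestein path then runs power-of-two FFTs of length 256.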
      // ... (a, b and pos_j_base_powered are device-allocated scratch
      // buffers used only by the Bluestein path)
      if (!is_power_of_two) {
        // Precompute the chirp pos_j_base_powered[j] = exp(i * pi * j^2 / line_len);
        // its sign is chosen so that the conjugations in
        // processDataLineBluestein below produce the forward DFT kernel.
        for (int j = 0; j < line_len + 1; ++j) {
          double arg = ((EIGEN_PI * j) * j) / line_len;
          std::complex<double> tmp(numext::cos(arg), numext::sin(arg));
          pos_j_base_powered[j] = static_cast<ComplexScalar>(tmp);
        }
      }
      for (Index partial_index = 0; partial_index < m_size / line_len; ++partial_index) {
        const Index base_offset = getBaseOffsetFromIndex(partial_index, dim);

        // Gather one line of data into line_buf.
        const Index stride = m_strides[dim];
        if (stride == 1) {
          m_device.memcpy(line_buf, &buf[base_offset], line_len * sizeof(ComplexScalar));
        } else {
          Index offset = base_offset;
          for (int j = 0; j < line_len; ++j, offset += stride) {
            line_buf[j] = buf[offset];
          }
        }
        // Transform the line in place.
        if (is_power_of_two) {
          processDataLineCooleyTukey(line_buf, line_len, log_len);
        } else {
          processDataLineBluestein(line_buf, line_len, good_composite, log_len, a, b, pos_j_base_powered);
        }
        // Write the line back; the inverse transform is scaled by 1/line_len.
        Index offset = base_offset;
        const ComplexScalar div_factor = ComplexScalar(1.0 / line_len, 0);
        for (int j = 0; j < line_len; ++j, offset += stride) {
          buf[offset] = (FFTDir == FFT_FORWARD) ? line_buf[j] : line_buf[j] * div_factor;
        }
      }
      m_device.deallocate(line_buf);
      if (!is_power_of_two) {
        m_device.deallocate(a);
        m_device.deallocate(b);
        m_device.deallocate(pos_j_base_powered);
      }
    }
    // When only the real or imaginary part was requested, extract it into
    // the output and release the temporary complex buffer.
    if (!write_to_out) {
      for (Index i = 0; i < m_size; ++i) {
        data[i] = PartOf<FFTResultType>()(buf[i]);
      }
      m_device.deallocate(buf);
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool isPowerOfTwo(Index x) {
    eigen_assert(x > 0);
    // A power of two has a single bit set, so x & (x - 1) clears it to 0.
    return !(x & (x - 1));
  }
  // Returns the smallest power of two no less than 2 * n - 1, used as the
  // padded length for Bluestein's algorithm.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static Index findGoodComposite(Index n) {
    Index i = 2;
    while (i < 2 * n - 1) i *= 2;
    return i;
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static Index getLog2(Index m) {
    Index log2m = 0;
    while (m >>= 1) log2m++;
    return log2m;
  }
  // In-place FFT of a single line whose length is a power of two.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void processDataLineCooleyTukey(ComplexScalar* line_buf, Index line_len, Index log_len) {
    eigen_assert(isPowerOfTwo(line_len));
    scramble_FFT(line_buf, line_len);
    compute_1D_Butterfly<FFTDir>(line_buf, line_len, log_len);
  }
  // Bluestein's FFT for arbitrary line lengths; m = good_composite is a
  // power of two >= 2 * line_len - 1 used as the padded length.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void processDataLineBluestein(ComplexScalar* line_buf, Index line_len, Index good_composite, Index log_len, ComplexScalar* a, ComplexScalar* b, const ComplexScalar* pos_j_base_powered) {
    Index n = line_len;
    Index m = good_composite;
    ComplexScalar* data = line_buf;

    // Modulate the input by the chirp and zero-pad it to length m.
    for (Index i = 0; i < n; ++i) {
      if (FFTDir == FFT_FORWARD) {
        a[i] = data[i] * numext::conj(pos_j_base_powered[i]);
      } else {
        a[i] = data[i] * pos_j_base_powered[i];
      }
    }
    for (Index i = n; i < m; ++i) {
      a[i] = ComplexScalar(0, 0);
    }

    // Build the circular convolution kernel; the chirp is symmetric in its
    // lag, so the tail of b holds the negative-lag values.
    for (Index i = 0; i < n; ++i) {
      if (FFTDir == FFT_FORWARD) {
        b[i] = pos_j_base_powered[i];
      } else {
        b[i] = numext::conj(pos_j_base_powered[i]);
      }
    }
    for (Index i = n; i < m - n; ++i) {
      b[i] = ComplexScalar(0, 0);
    }
    for (Index i = m - n; i < m; ++i) {
      if (FFTDir == FFT_FORWARD) {
        b[i] = pos_j_base_powered[m - i];
      } else {
        b[i] = numext::conj(pos_j_base_powered[m - i]);
      }
    }

    // Convolve a with b: two forward FFTs of length m ...
    scramble_FFT(a, m);
    compute_1D_Butterfly<FFT_FORWARD>(a, m, log_len);
    scramble_FFT(b, m);
    compute_1D_Butterfly<FFT_FORWARD>(b, m, log_len);

    // ... a pointwise product, and an inverse FFT scaled by 1/m.
    for (Index i = 0; i < m; ++i) {
      a[i] *= b[i];
    }
    scramble_FFT(a, m);
    compute_1D_Butterfly<FFT_REVERSE>(a, m, log_len);
    for (Index i = 0; i < m; ++i) {
      a[i] /= m;
    }

    // Demodulate the first n outputs by the chirp.
    for (Index i = 0; i < n; ++i) {
      if (FFTDir == FFT_FORWARD) {
        data[i] = a[i] * numext::conj(pos_j_base_powered[i]);
      } else {
        data[i] = a[i] * pos_j_base_powered[i];
      }
    }
  }
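  // Why this works: Bluestein's identity j*k = (j^2 + k^2 - (k - j)^2) / 2
  // rewrites the length-n DFT as a convolution. With the chirp
  // w[j] = exp(i * pi * j^2 / n), the forward transform becomes
  //   X[k] = conj(w[k]) * sum_j (x[j] * conj(w[j])) * w[k - j],
  // which is exactly the circular convolution of a and b formed above,
  // evaluated with three power-of-two FFTs of the padded length m >= 2n - 1.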
  // Permutes data into bit-reversed index order (n must be a power of two).
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static void scramble_FFT(ComplexScalar* data, Index n) {
    eigen_assert(isPowerOfTwo(n));
    Index j = 1;
    for (Index i = 1; i < n; ++i) {
      if (j > i) {
        std::swap(data[j - 1], data[i - 1]);
      }
      Index m = n >> 1;
      while (m >= 2 && j > m) {
        j -= m;
        m >>= 1;
      }
      j += m;
    }
  }
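  // For n = 8 this swaps elements whose 3-bit indices are mirror images,
  // reordering 0,1,2,3,4,5,6,7 into 0,4,2,6,1,5,3,7, the input order the
  // butterfly stages expect.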
  // 4-point butterfly; tmp[0..3] (computed in elided lines) hold the two
  // 2-point sub-transforms with the direction-dependent twiddle applied.
  template <int Dir>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void butterfly_4(ComplexScalar* data) {
    // ...
    data[0] = tmp[0] + tmp[2];
    data[1] = tmp[1] + tmp[3];
    data[2] = tmp[0] - tmp[2];
    data[3] = tmp[1] - tmp[3];
  }
  // 8-point butterfly; tmp_1[0..7] (computed in elided lines) hold the four
  // 2-point sub-transforms.
  template <int Dir>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void butterfly_8(ComplexScalar* data) {
    // ...
    tmp_2[0] = tmp_1[0] + tmp_1[2];
    tmp_2[1] = tmp_1[1] + tmp_1[3];
    tmp_2[2] = tmp_1[0] - tmp_1[2];
    tmp_2[3] = tmp_1[1] - tmp_1[3];
    tmp_2[4] = tmp_1[4] + tmp_1[6];
// SQRT2DIV2 is sqrt(2) / 2 = cos(pi / 4), the magnitude of the odd
// eighth-root-of-unity twiddles used in the elided tmp_2[5..7] lines.
#define SQRT2DIV2 0.7071067811865476
    // ...
    // Final combine of the two 4-point halves into the 8-point result.
    data[0] = tmp_2[0] + tmp_2[4];
    data[1] = tmp_2[1] + tmp_2[5];
    data[2] = tmp_2[2] + tmp_2[6];
    data[3] = tmp_2[3] + tmp_2[7];
    data[4] = tmp_2[0] - tmp_2[4];
    data[5] = tmp_2[1] - tmp_2[5];
    data[6] = tmp_2[2] - tmp_2[6];
    data[7] = tmp_2[3] - tmp_2[7];
  }
  // Merges two half-length transforms into one of length n = 2^n_power_of_2.
  template <int Dir>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void butterfly_1D_merge(ComplexScalar* data, Index n, Index n_power_of_2) {
    const RealScalar wtemp = m_sin_PI_div_n_LUT[n_power_of_2];
    const RealScalar wpi = (Dir == FFT_FORWARD)
                               ? m_minus_sin_2_PI_div_n_LUT[n_power_of_2]
                               : -m_minus_sin_2_PI_div_n_LUT[n_power_of_2];
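    // The LUTs appear to store wtemp = -2 * sin^2(pi / n) and
    // wpi = -sin(2 * pi / n), so that ComplexScalar(wtemp, wpi) equals
    // exp(-2i * pi / n) - 1 in the forward direction. Expressing the real
    // part through cos(x) - 1 = -2 * sin^2(x / 2) avoids the cancellation
    // that computing cos(2 * pi / n) - 1 directly would suffer for large n.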
    // ... (the elided loop computes temp0..temp3, the twiddled values of
    // data[i + k + n2] for k = 0..3, with n2 = n / 2)
    data[i + 1 + n2] = data[i + 1] - temp1;
    data[i + 1] += temp1;

    data[i + 2 + n2] = data[i + 2] - temp2;
    data[i + 2] += temp2;

    data[i + 3 + n2] = data[i + 3] - temp3;
    data[i + 3] += temp3;
    // ...
  }
  // Recursive driver: split in halves until the hand-unrolled 8/4/2-point
  // kernels take over.
  template <int Dir>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void compute_1D_Butterfly(ComplexScalar* data, Index n, Index n_power_of_2) {
    eigen_assert(isPowerOfTwo(n));
    if (n > 8) {
      compute_1D_Butterfly<Dir>(data, n / 2, n_power_of_2 - 1);
      compute_1D_Butterfly<Dir>(data + n / 2, n / 2, n_power_of_2 - 1);
      butterfly_1D_merge<Dir>(data, n, n_power_of_2);
    } else if (n == 8) {
      butterfly_8<Dir>(data);
    } else if (n == 4) {
      butterfly_4<Dir>(data);
    } else if (n == 2) {
      butterfly_2<Dir>(data);
    }
  }
  // Maps the index of a line (the tensor with dimension omitted_dim
  // removed) to the flat offset of that line's first element.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index getBaseOffsetFromIndex(Index index, Index omitted_dim) const {
    Index result = 0;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 1; i > omitted_dim; --i) {
        const Index partial_m_stride = m_strides[i] / m_dimensions[omitted_dim];
        const Index idx = index / partial_m_stride;
        index -= idx * partial_m_stride;
        result += idx * m_strides[i];
      }
      result += index;
    } else {
      for (Index i = 0; i < omitted_dim; ++i) {
        const Index partial_m_stride = m_strides[i] / m_dimensions[omitted_dim];
        const Index idx = index / partial_m_stride;
        index -= idx * partial_m_stride;
        result += idx * m_strides[i];
      }
      result += index;
    }
    return result;
  }
  // Steps along the omitted dimension from a line's base offset.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index getIndexFromOffset(Index base, Index omitted_dim, Index offset) const {
    Index result = base + offset * m_strides[omitted_dim];
    return result;
  }
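  // Worked example: for a ColMajor tensor of dimensions (2, 3, 4), the
  // strides are {1, 2, 6}. Transforming along omitted_dim = 1 yields
  // m_size / line_len = 8 lines; partial_index = 5 decodes to idx = 2 along
  // dimension 2 (offset 12) plus remainder 1 along dimension 0, so the base
  // offset is 13 and the line's elements sit at offsets 13, 15, 17.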
  // ... (data members m_strides, m_dimensions, m_size, m_fft, m_impl,
  // m_data, m_device and the sine lookup tables elided)
};

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_FFT_H