30#pragma GCC system_header
34#include "../../math/constants.hpp"
// Template parameter list: input element type, output element type, a rank,
// and an integer _cudaGPU parameter that defaults to 0.
// NOTE(review): the declaration this list belongs to is not present in this
// excerpt; presumably it is the primary mftT template -- confirm.
54template <
typename _inputT,
typename _outputT,
size_t _rank,
int _cudaGPU = 0>
// Three-parameter template header followed by member declarations.
// NOTE(review): the class head and braces are missing from this excerpt;
// presumably this is the partial specialization mftT<..., 0> whose members
// are defined out-of-class further down -- confirm.
87template <
typename _inputT,
typename _outputT,
size_t _rank>
// Exposes the _rank template argument as a static constant.
97 static const size_t rank =
_rank;
// Eigen array aliases for the input and output element types. Both extents
// are given as -1 (presumably Eigen::Dynamic, i.e. run-time sized).
99 typedef Eigen::Array<inputT, -1, -1> eigenArrayInT;
100 typedef Eigen::Array<outputT, -1, -1> eigenArrayOutT;
// Complex-valued Eigen matrix alias, also with -1 for both extents.
102 typedef Eigen::Matrix<complexT, -1, -1> eigenMatrixT;
// Member template whose trailing enable_if parameter selects crank == 1;
// this overload is explicitly deleted. (Leading parameters not shown here.)
125 template <
size_t crank = _rank>
132 typename std::enable_if<crank == 1>::type * = 0 ) =
delete;
// Member template enabled when crank == 2; declared, not deleted.
// NOTE(review): presumably the (nx, ny, ...) constructor -- confirm.
135 template <
size_t crank = _rank>
145 typename std::enable_if<crank == 2>::type * = 0 );
// Member template whose trailing enable_if parameter selects crank == 3;
// this overload is explicitly deleted.
148 template <
size_t crank = _rank>
161 typename std::enable_if<crank == 3>::type * = 0 ) =
delete;
// Member template enabled when crank == 2; declared, not deleted.
// NOTE(review): presumably plan(nx, ny, ...) -- confirm.
164 template <
size_t crank = _rank>
174 typename std::enable_if<crank == 2>::type * = 0 );
// Tail of a const member declaration taking the input array by const
// reference. NOTE(review): presumably operator()(out, in) -- confirm.
178 const eigenArrayInT &
in )
const;
// Out-of-class definition header of the default constructor for the
// specialization whose fourth template argument is 0.
// The constructor body is not present in this excerpt.
181template <
typename inputT,
typename outputT,
size_t rank>
182mftT<inputT, outputT, rank, 0>::mftT()
// Out-of-class definition header of the constructor that is enabled only when
// crank == 2. It takes the two sizes (nx, ny), a transform direction, x and y
// offsets, and osFac.
// NOTE(review): the body is not present in this excerpt; presumably it
// forwards to plan() -- confirm.
186template <
typename inputT,
typename outputT,
size_t rank>
187template <
size_t crank>
188mftT<inputT, outputT, rank, 0>::mftT(
189 int nx,
int ny,
dir ndir, realT
xoff, realT
yoff, realT
osFac,
// Unnamed SFINAE parameter: only well-formed for crank == 2.
typename std::enable_if<crank == 2>::type * )
// Out-of-class definition of plan(), enabled only when crank == 2.
// From the visible lines it: compares the requested parameters against the
// stored ones, can throw std::invalid_argument, resizes the two stored
// matrices m_dftR and m_dftC, and iterates over (cc, rr) index pairs.
// NOTE(review): braces and many statements are missing from this excerpt, so
// the comments below describe only the lines that are visible.
194template <
typename inputT,
typename outputT,
size_t rank>
195template <
size_t crank>
196void mftT<inputT, outputT, rank, 0>::plan(
197 int nx,
int ny,
dir ndir, realT
xOff, realT
yOff, realT
osFac,
typename std::enable_if<crank == 2>::type * )
// Checks whether every requested parameter equals the stored value.
// NOTE(review): presumably an early return when nothing changed -- confirm.
// The realT members are compared with exact ==, so only bit-identical
// offsets / oversampling factors count as "unchanged".
199 if(m_szX == nx && m_szY == ny && m_dir ==
ndir && m_xOff ==
xOff && m_yOff ==
yOff && m_osFac ==
osFac)
// NOTE(review): the guarding condition is not visible; per the message this is
// raised when the two sizes differ.
213 throw std::invalid_argument(
"MFT of non-square size is not implemented. nx must equal ny." );
// Both stored matrices are sized m_szX by m_szX.
217 m_dftR.resize( m_szX, m_szX );
218 m_dftC.resize( m_szX, m_szX );
// Size scaled by the stored osFac factor; used as the divisor below.
221 realT
osN = m_szX * m_osFac;
// Normalization: reciprocal of the product of the two sizes.
226 realT norm = 1.0 / ( m_szX * m_szY );
// First loop nest: cc runs over [0, m_szY), rr over [0, m_szX).
228 for(
int cc = 0;
cc < m_szY; ++
cc )
232 for(
int rr = 0;
rr < m_szX; ++
rr )
// Offset-shifted rr index times ccx, divided by osN.
// (ccx is defined on a line not present in this excerpt.)
234 realT x = (
rr - m_xOff ) *
ccx /
osN;
// Second loop nest over the same index ranges.
251 for(
int cc = 0;
cc < m_szY; ++
cc )
// Values above m_szY / 2 are shifted down by m_szY (become negative).
254 if(
ccx > m_szY / 2 )
255 ccx = -1 * ( m_szY -
ccx );
257 for(
int rr = 0;
rr < m_szX; ++
rr )
// Values above m_szX / 2 are shifted down by m_szX (become negative).
260 if(
rrx > m_szX / 2 )
261 rrx = -1 * ( m_szX -
rrx );
// NOTE(review): here the offset m_xOff is subtracted from the outer index cc
// and multiplied by the wrapped rrx, whereas the first loop nest subtracts it
// from the inner index rr and multiplies by ccx. m_yOff is not used in any
// visible line. Confirm against the full source that this asymmetry is
// intended.
263 realT x =
rrx * (
cc - m_xOff ) /
osN;
// Call operator (const): writes to `out` the product m_dftR * in * m_dftC.
// `in` is viewed as a matrix for the multiplication and the result is
// converted back to an array before assignment.
// NOTE(review): no size check on `in` is visible in this excerpt; presumably
// its dimensions must be compatible with the m_dftR / m_dftC sizes set up by
// plan() -- confirm.
275template <
typename inputT,
typename outputT,
size_t rank>
276void mftT<inputT, outputT, rank, 0>::operator()( eigenArrayOutT &
out,
const eigenArrayInT &
in )
const
278 out = ( m_dftR *
in.matrix() * m_dftC ).
array();
dir
Directions of the Fourier Transform.
@ forward
Specifies the forward transform.
T sign(T x)
The sign function.
constexpr floatT six_fifths()
Return 6/5 in the specified precision.
std::complex< realT > complexT
The complex data type.
_realT realT
The real data type (_realT is actually defined in specializations).