30#pragma GCC system_header
34#include "../../math/constants.hpp"
43template <
typename _inputT,
typename _outputT,
size_t _rank,
int _cudaGPU = 0>
60template <
typename _inputT,
typename _outputT,
size_t _rank>
70 static const size_t rank =
_rank;
72 typedef Eigen::Array<inputT, -1, -1> eigenArrayInputT;
73 typedef Eigen::Array<outputT, -1, -1> eigenArrayOutputT;
75 typedef Eigen::Matrix<complexT, -1, -1> eigenMatrixT;
97 template <
size_t crank = _rank>
103 typename std::enable_if<crank == 1>::type * = 0 ) =
delete;
106 template <
size_t crank = _rank>
114 typename std::enable_if<crank == 2>::type * = 0 );
117 template <
size_t crank = _rank>
127 typename std::enable_if<crank == 3>::type * = 0 ) =
delete;
130 template <
size_t crank = _rank>
138 typename std::enable_if<crank == 2>::type * = 0 );
142 const eigenArrayInputT &
in )
const;
145template <
typename inputT,
typename outputT,
size_t rank>
146mdftT<inputT, outputT, rank, 0>::mdftT()
150template <
typename inputT,
typename outputT,
size_t rank>
151template <
size_t crank>
152mdftT<inputT, outputT, rank, 0>::mdftT(
153 int nx,
int ny,
int ndir, realT
xoff, realT
yoff, realT
osFac,
typename std::enable_if<crank == 2>::type * )
158template <
typename inputT,
typename outputT,
size_t rank>
159template <
size_t crank>
160void mdftT<inputT, outputT, rank, 0>::plan(
161 int nx,
int ny,
int ndir, realT
xOff, realT
yOff, realT
osFac,
typename std::enable_if<crank == 2>::type * )
172 throw std::invalid_argument(
"MDFT of non-square size is not implemented. nx must equal ny." );
176 m_dftR.resize( m_szX, m_szX );
177 m_dftC.resize( m_szX, m_szX );
180 int osN =
ceil( m_szX * m_osFac );
184 realT norm = 1.0 / ( m_szX * m_szY );
185 for(
int cc = 0;
cc < m_szY; ++
cc )
187 for(
int rr = 0;
rr < m_szX; ++
rr )
189 realT x =
static_cast<realT
>(
rr - m_xOff ) *
static_cast<realT
>(
cc );
193 x =
static_cast<realT
>(
rr ) *
static_cast<realT
>(
cc - m_yOff );
203 for(
int rr = 0;
rr < m_szX; ++
rr )
205 for(
int cc = 0;
cc < m_szY; ++
cc )
207 realT x =
static_cast<realT
>(
rr - m_xOff ) *
static_cast<realT
>(
cc );
211 x =
static_cast<realT
>(
rr ) *
static_cast<realT
>(
cc - m_yOff );
219template <
typename inputT,
typename outputT,
size_t rank>
220void mdftT<inputT, outputT, rank, 0>::operator()( eigenArrayOutputT &
out,
const eigenArrayInputT &
in )
const
222 out = ( m_dftR *
in.matrix() * m_dftC ).
array();
The fast Fourier transform interface.
constexpr floatT six_fifths()
Return 6/5 in the specified precision.
std::complex< realT > complexT
The complex data type.
_realT realT
The real data type (_realT is actually defined in specializations).