30#include "../cuda/templateCufft.hpp"
31#include "../cuda/cudaPtr.hpp"
/// Apparently the CUDA (cuFFT) specialization of fftT: the out-of-class
/// member definitions in this listing are written for
/// fftT<inputT, outputT, rank, 1>.
///
/// NOTE(review): this is a documentation-extract fragment; the class head,
/// braces and several member declarations are not visible here.
///
/// \tparam _inputT  input element type (wrapped by cuda::cudaPtr and mapped
///                  through cuda::cpp2cudaType)
/// \tparam _outputT output element type (same treatment as _inputT)
/// \tparam _rank    dimensionality of the transform; only 2 is currently
///                  usable, because the rank-1 and rank-3 code paths contain
///                  static_assert(crank == 2, ...)
53template <
typename _inputT,
typename _outputT,
size_t _rank>
/// Transform rank, copied from the _rank template parameter.
61 static const size_t rank =
_rank;
/// CUDA-side element types, obtained from cuda::cpp2cudaType for the
/// input and output C++ element types.
67 typedef typename cuda::cpp2cudaType<inputT>::cudaType cudaInT;
68 typedef typename cuda::cpp2cudaType<outputT>::cudaType cudaOutT;
/// cuda::cudaPtr wrapper type for input arrays.
70 typedef cuda::cudaPtr<inputT> cudaPtrInT;
/// cuda::cudaPtr wrapper type for output arrays.
72 typedef cuda::cudaPtr<outputT> cudaPtrOutT;
/// Rank-specific overloads (presumably the constructors, matching the
/// out-of-class fftT::fftT definitions). Each is selected with std::enable_if
/// on crank, which defaults to _rank: crank == 1, crank == 2, crank == 3.
91 template <
int crank = _rank>
97 typename std::enable_if<crank == 1>::type * = 0 );
100 template <
int crank = _rank>
107 typename std::enable_if<crank == 2>::type * = 0 );
110 template <
int crank = _rank>
118 typename std::enable_if<crank == 3>::type * = 0 );
/// Rank-specific overloads (presumably plan(), matching the out-of-class
/// fftT::plan definitions), selected the same way for crank == 1, 2, 3.
134 template <
int crank = _rank>
140 typename std::enable_if<crank == 1>::type * = 0 );
143 template <
int crank = _rank>
150 typename std::enable_if<crank == 2>::type * = 0 );
153 template <
int crank = _rank>
161 typename std::enable_if<crank == 3>::type * = 0 );
/// Tail of a const member declaration whose last parameter is
/// `cudaPtrInT & in`. The head is not visible; presumably it is
/// operator()(cudaPtrOutT &out, cudaPtrInT &in) const.
170 cudaPtrInT &
in )
const;
/// Default constructor of fftT<inputT, outputT, rank, 1>.
/// NOTE(review): the body is not visible in this extract.
173template <
typename inputT,
typename outputT,
size_t rank>
174fftT<inputT, outputT, rank, 1>::fftT()
/// Rank-1 constructor of fftT<inputT, outputT, rank, 1>, enabled only when
/// crank == 1.
///
/// Rank 1 is not implemented for this specialization. The body asserts
/// crank == 2, which cannot hold in this overload, so instantiating it fails
/// to compile with "only rank 2 is currently supported for cuda fftT".
///
/// \param nx       presumably the transform length
/// \param ndir     transform direction (ft::dir)
/// \param inPlace  in-place flag (its use is not visible here)
178template <
typename inputT,
typename outputT,
size_t rank>
180fftT<inputT, outputT, rank, 1>::fftT(
int nx,
ft::dir ndir,
bool inPlace,
typename std::enable_if<crank == 1>::type * )
// Deliberately unsatisfiable in this overload: rejects rank-1 use at compile time.
182 static_assert(
crank == 2,
"only rank 2 is currently supported for cuda fftT" );
/// Rank-2 constructor of fftT<inputT, outputT, rank, 1>, enabled only when
/// crank == 2. This is the only rank the CUDA specialization accepts.
/// NOTE(review): the body is not visible in this extract.
///
/// \param nx       presumably the size of the first dimension
/// \param ny       presumably the size of the second dimension
/// \param ndir     transform direction (ft::dir)
/// \param inPlace  in-place flag (its use is not visible here)
188template <
typename inputT,
typename outputT,
size_t rank>
190fftT<inputT, outputT, rank, 1>::fftT(
191 int nx,
int ny,
ft::dir ndir,
bool inPlace,
typename std::enable_if<crank == 2>::type * )
/// Rank-3 constructor of fftT<inputT, outputT, rank, 1>, enabled only when
/// crank == 3.
///
/// Rank 3 is not implemented for this specialization. The body asserts
/// crank == 2, which cannot hold in this overload, so instantiating it fails
/// to compile with "only rank 2 is currently supported for cuda fftT".
///
/// \param nx, ny, nz  presumably the sizes of the three dimensions
/// \param ndir        transform direction (ft::dir)
/// \param inPlace     in-place flag (its use is not visible here)
196template <
typename inputT,
typename outputT,
size_t rank>
198fftT<inputT, outputT, rank, 1>::fftT(
199 int nx,
int ny,
int nz,
ft::dir ndir,
bool inPlace,
typename std::enable_if<crank == 3>::type * )
// Deliberately unsatisfiable in this overload: rejects rank-3 use at compile time.
201 static_assert(
crank == 2,
"only rank 2 is currently supported for cuda fftT" );
/// Destructor of fftT<inputT, outputT, rank, 1>.
/// NOTE(review): the body is not visible in this extract; presumably it
/// releases the cuFFT plan (e.g. via destroyPlan()) — confirm.
208template <
typename inputT,
typename outputT,
size_t rank>
209fftT<inputT, outputT, rank, 1>::~fftT()
/// destroyPlan() for fftT<inputT, outputT, rank, 1>.
/// NOTE(review): the body is not visible in this extract; presumably it
/// releases m_plan (cufftDestroy) — confirm against the full source.
214template <
typename inputT,
typename outputT,
size_t rank>
215void fftT<inputT, outputT, rank, 1>::destroyPlan()
/// direction() for fftT<inputT, outputT, rank, 1>; returns an ft::dir.
/// NOTE(review): the body is not visible in this extract; presumably it
/// returns the direction chosen at construction/plan time — confirm.
229template <
typename inputT,
typename outputT,
size_t rank>
230ft::dir fftT<inputT, outputT, rank, 1>::direction()
/// Rank-1 plan() for fftT<inputT, outputT, rank, 1>, enabled only when
/// crank == 1.
///
/// Rank 1 is not implemented for this specialization. The body asserts
/// crank == 2, which cannot hold here, so instantiating this overload fails
/// to compile with "only rank 2 is currently supported for cuda fftT".
/// NOTE(review): the parameters between nx and the enable_if tag are not
/// visible in this extract.
235template <
typename inputT,
typename outputT,
size_t rank>
237void fftT<inputT, outputT, rank, 1>::plan(
int nx,
240 typename std::enable_if<crank == 1>::type * )
// Deliberately unsatisfiable in this overload: rejects rank-1 use at compile time.
242 static_assert(
crank == 2,
"only rank 2 is currently supported for cuda fftT" );
/// Rank-2 plan() for fftT<inputT, outputT, rank, 1>, enabled only when
/// crank == 2.
///
/// Of the statements visible here, it creates the cuFFT plan into m_plan
/// with mx::cuda::cufftPlan2d<cudaInT, cudaOutT>(&m_plan, nx, ny).
/// The remaining body lines are not visible in this extract.
///
/// \param nx, ny   transform dimensions, forwarded to cufftPlan2d
/// \param ndir     transform direction (ft::dir)
/// \param inPlace  in-place flag (its use is not visible here)
///
/// NOTE(review): cuFFT's cufftPlan2d treats nx as the slowest-varying (row)
/// dimension; confirm this matches the caller's memory layout.
/// NOTE(review): if the hidden lines do not release an existing plan first,
/// calling plan() twice would leak the previous m_plan — confirm.
245template <
typename inputT,
typename outputT,
size_t rank>
247void fftT<inputT, outputT, rank, 1>::plan(
248 int nx,
int ny,
ft::dir ndir,
bool inPlace,
typename std::enable_if<crank == 2>::type * )
// NOTE(review): the result of cufftPlan2d (if it returns one) is discarded on
// this line; consider checking it so plan-creation failures are not silent.
261 mx::cuda::cufftPlan2d<cudaInT, cudaOutT>( &m_plan, nx, ny );
/// Rank-3 plan() for fftT<inputT, outputT, rank, 1>, enabled only when
/// crank == 3.
///
/// Rank 3 is not implemented for this specialization. The body asserts
/// crank == 2, which cannot hold here, so instantiating this overload fails
/// to compile with "only rank 2 is currently supported for cuda fftT".
264template <
typename inputT,
typename outputT,
size_t rank>
266void fftT<inputT, outputT, rank, 1>::plan(
267 int nx,
int ny,
int nz,
ft::dir ndir,
bool inPlace,
typename std::enable_if<crank == 3>::type * )
// Deliberately unsatisfiable in this overload: rejects rank-3 use at compile time.
269 static_assert(
crank == 2,
"only rank 2 is currently supported for cuda fftT" );
272template <
typename inputT,
typename outputT,
size_t rank>
273cufftResult fftT<inputT, outputT, rank, 1>::operator()( cudaOutT *
out, cudaInT *
in )
const
275 return mx::cuda::cufftExec<cudaInT, cudaOutT>( m_plan,
in,
out, m_cufftDirection );
278template <
typename inputT,
typename outputT,
size_t rank>
279cufftResult fftT<inputT, outputT, rank, 1>::operator()( cudaPtrOutT &
out, cudaPtrInT &
in )
const
281 return mx::cuda::cufftExec<cudaInT, cudaOutT>( m_plan,
in.data(),
out.data(), m_cufftDirection );
The Fast Fourier Transform interface.
dir
Directions of the Fourier Transform.
backward
Specifies the backward transform.
forward
Specifies the forward transform.
constexpr floatT six_fifths()
Returns 6/5 in the precision of floatT.
std::complex< realT > complexT
The complex data type.
_realT realT
The real data type (_realT is actually defined in specializations).