mxlib
c++ tools for analyzing astronomical data and other tasks by Jared R. Males. [git repo]
Loading...
Searching...
No Matches
fftTcuda.hpp
1/** \file fftTcuda.hpp
2 * \brief The Fast Fourier Transform interface for CUDA (cuFFT)
3 * \ingroup ft_files
4 * \author Jared R. Males (jaredmales@gmail.com)
5 *
6 */
7
8//***********************************************************************//
9// Copyright 2015-2025 Jared R. Males (jaredmales@gmail.com)
10//
11// This file is part of mxlib.
12//
13// mxlib is free software: you can redistribute it and/or modify
14// it under the terms of the GNU General Public License as published by
15// the Free Software Foundation, either version 3 of the License, or
16// (at your option) any later version.
17//
18// mxlib is distributed in the hope that it will be useful,
19// but WITHOUT ANY WARRANTY; without even the implied warranty of
20// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21// GNU General Public License for more details.
22//
23// You should have received a copy of the GNU General Public License
24// along with mxlib. If not, see <http://www.gnu.org/licenses/>.
25//***********************************************************************//
26
27#ifndef fftTcuda_hpp
28#define fftTcuda_hpp
29
30#include "../cuda/templateCufft.hpp"
31#include "../cuda/cudaPtr.hpp"
32
33#include "fftT.hpp"
34
35namespace mx
36{
37namespace math
38{
39namespace ft
40{
41
42/// Fast Fourier Transforms using the \ref cuda "CUDA library"
43/** The \ref cufft_templates "CUFFT Templates" type resolution system is used to allow the compiler
44 * to access the right plan and types for the transforms based on inputT and outputT.
45 *
46 *
47 * \tparam _inputT is the input type of the transform, can be either real or complex
48 * \tparam _outputT is the output type of the transform, can be either real or complex
49 * \tparam _rank is the rank of the transform. Limited to 2 for now.
50 *
51 * \ingroup fft
52 */
53template <typename _inputT, typename _outputT, size_t _rank>
54class fftT<_inputT, _outputT, _rank, 1>
55{
56
57 public:
58 typedef _inputT inputT;
59 typedef _outputT outputT;
60
61 static const size_t rank = _rank;
62
63 typedef typename fftwTypeSpec<inputT, outputT>::realT realT;
64
65 typedef typename fftwTypeSpec<inputT, outputT>::complexT complexT;
66
67 typedef typename cuda::cpp2cudaType<inputT>::cudaType cudaInT;
68 typedef typename cuda::cpp2cudaType<outputT>::cudaType cudaOutT;
69
70 typedef cuda::cudaPtr<inputT> cudaPtrInT;
71
72 typedef cuda::cudaPtr<outputT> cudaPtrOutT;
73
74 typedef cufftHandle planT;
75
76 protected:
77 dir m_direction{ dir::forward }; ///< Direction of this FFT, either dir::forward (default) or dir::backward
78 int m_cufftDirection{ CUFFT_FORWARD }; ///< Direction to pass to CUFFT routine. Kept synchronized with m_direction.
79
80 int m_szX{ 0 }; ///< Size of the x dimension
81 int m_szY{ 0 }; ///< Size of the y dimension
82 int m_szZ{ 0 }; ///< size of the z dimension
83
84 planT m_plan {0}; ///< The cufft handle object.
85
86 public:
87 /// Default c'tor
89
90 /// Constructor for rank 1 FFT.
91 template <int crank = _rank>
92 fftT( int nx, ///< [in] the desired size of the FFT
93 dir ndir = dir::forward, ///< [in] [optional] direction of this FFT, either dir::forward (default) or
94 ///< dir::backward
95 bool inPlace = false, /**< [in] [optional] whether or not this is an in-place transform.
96 Default is false, out-of-place.*/
97 typename std::enable_if<crank == 1>::type * = 0 );
98
99 /// Constructor for rank 2 FFT.
100 template <int crank = _rank>
101 fftT( int nx, ///< [in] the desired x size of the FFT
102 int ny, ///< [in] the desired y size of the FFT
103 dir ndir = dir::forward, /**< [in] [optional] direction of this FFT, either dir::forward
104 (default) or dir::backward */
105 bool inPlace = false, /**< [in] [optional] whether or not this is an in-place transform.
106 Default is false, out-of-place. */
107 typename std::enable_if<crank == 2>::type * = 0 );
108
109 /// Constructor for rank 3 FFT.
110 template <int crank = _rank>
111 fftT( int nx, ///< [in] the desired x size of the FFT
112 int ny, ///< [in] the desired y size of the FFT
113 int nz, ///< [in] the desired z size of the FFT
114 dir ndir = dir::forward, /**< [in] [optional] direction of this FFT, either dir::forward (default)
115 or dir::backward*/
116 bool inPlace = false, /**< [in] [optional] whether or not this is an in-place
117 transform. Default is false, out-of-place. */
118 typename std::enable_if<crank == 3>::type * = 0 );
119
120 /// Destructor
122
123 /// Destroy (de-allocate) the plan
125
126 /// Get the direction of this FFT
127 /** The direction is either dir::forward or dir::backward.
128 *
129 * \returns the current value of m_direction.
130 */
132
133 /// Planning routine for rank 1 transforms.
134 template <int crank = _rank>
135 void plan( int nx, ///< [in] the desired size of the FFT
136 ft::dir ndir = dir::forward, /**< [in] [optional] direction of this FFT, either dir::forward (default)
137 or dir::backward */
138 bool inPlace = false, /**< [in] [optional] whether or not this is an in-place transform.
139 Default is false, out-of-place. */
140 typename std::enable_if<crank == 1>::type * = 0 );
141
142 /// Planning routine for rank 2 transforms.
143 template <int crank = _rank>
144 void plan( int nx, ///< [in] the desired x size of the FFT
145 int ny, ///< [in] the desired y size of the FFT
146 ft::dir ndir = dir::forward, /**< [in] [optional] direction of this FFT, either dir::forward (default)
147 or dir::backward */
148 bool inPlace = false, /**< [in] [optional] whether or not this is an in-place transform.
149 Default is false, out-of-place. */
150 typename std::enable_if<crank == 2>::type * = 0 );
151
152 /// Planning routine for rank 3 transforms.
153 template <int crank = _rank>
154 void plan( int nx, ///< [in] the desired x size of the FFT
155 int ny, ///< [in] the desired y size of the FFT
156 int nz, ///< [in] the desired z size of the FFT
157 ft::dir ndir = dir::forward, /**< [in] [optional] direction of this FFT, either dir::forward
158 (default) or dir::backward */
159 bool inPlace = false, /**< [in] [optional] whether or not this is an in-place transform.
160 Default is false, out-of-place. */
161 typename std::enable_if<crank == 3>::type * = 0 );
162
163 /// Conduct the FFT
164 cufftResult operator()( cudaOutT *out, ///< [out] [device] the output of the FFT, must be pre-allocated
165 cudaInT *in ///< [in] [device] the input to the FFT
166 ) const;
167
168 /// Conduct the MFT
169 cufftResult operator()( cudaPtrOutT &out, /**< [out] the output of the DFT */
170 cudaPtrInT &in /**< [in] the input to the DFT */ ) const;
171};
172
173template <typename inputT, typename outputT, size_t rank>
174fftT<inputT, outputT, rank, 1>::fftT()
175{
176}
177
178template <typename inputT, typename outputT, size_t rank>
179template <int crank>
180fftT<inputT, outputT, rank, 1>::fftT( int nx, ft::dir ndir, bool inPlace, typename std::enable_if<crank == 1>::type * )
181{
182 static_assert( crank == 2, "only rank 2 is currently supported for cuda fftT" );
183 // m_direction = ndir;
184
185 // plan( nx, ndir, inPlace );
186}
187
188template <typename inputT, typename outputT, size_t rank>
189template <int crank>
190fftT<inputT, outputT, rank, 1>::fftT(
191 int nx, int ny, ft::dir ndir, bool inPlace, typename std::enable_if<crank == 2>::type * )
192{
193 plan( nx, ny, ndir, inPlace );
194}
195
196template <typename inputT, typename outputT, size_t rank>
197template <int crank>
198fftT<inputT, outputT, rank, 1>::fftT(
199 int nx, int ny, int nz, ft::dir ndir, bool inPlace, typename std::enable_if<crank == 3>::type * )
200{
201 static_assert( crank == 2, "only rank 2 is currently supported for cuda fftT" );
202
203 // m_direction = ndir;
204
205 // plan( nx, ny, nz, ndir, inPlace );
206}
207
208template <typename inputT, typename outputT, size_t rank>
209fftT<inputT, outputT, rank, 1>::~fftT()
210{
211 destroyPlan();
212}
213
214template <typename inputT, typename outputT, size_t rank>
215void fftT<inputT, outputT, rank, 1>::destroyPlan()
216{
217 if( m_plan )
218 {
219 cufftDestroy( m_plan );
220 }
221
222 m_plan = 0;
223
224 m_szX = 0;
225 m_szY = 0;
226 m_szZ = 0;
227}
228
229template <typename inputT, typename outputT, size_t rank>
230ft::dir fftT<inputT, outputT, rank, 1>::direction()
231{
232 return m_direction;
233}
234
235template <typename inputT, typename outputT, size_t rank>
236template <int crank>
237void fftT<inputT, outputT, rank, 1>::plan( int nx,
239 bool inPlace,
240 typename std::enable_if<crank == 1>::type * )
241{
242 static_assert( crank == 2, "only rank 2 is currently supported for cuda fftT" );
243}
244
245template <typename inputT, typename outputT, size_t rank>
246template <int crank>
247void fftT<inputT, outputT, rank, 1>::plan(
248 int nx, int ny, ft::dir ndir, bool inPlace, typename std::enable_if<crank == 2>::type * )
249{
250 m_direction = ndir;
251
252 if( m_direction == dir::backward )
253 {
254 m_cufftDirection = CUFFT_INVERSE;
255 }
256 else
257 {
258 m_cufftDirection = CUFFT_FORWARD;
259 }
260
261 mx::cuda::cufftPlan2d<cudaInT, cudaOutT>( &m_plan, nx, ny );
262}
263
264template <typename inputT, typename outputT, size_t rank>
265template <int crank>
266void fftT<inputT, outputT, rank, 1>::plan(
267 int nx, int ny, int nz, ft::dir ndir, bool inPlace, typename std::enable_if<crank == 3>::type * )
268{
269 static_assert( crank == 2, "only rank 2 is currently supported for cuda fftT" );
270}
271
272template <typename inputT, typename outputT, size_t rank>
273cufftResult fftT<inputT, outputT, rank, 1>::operator()( cudaOutT *out, cudaInT *in ) const
274{
275 return mx::cuda::cufftExec<cudaInT, cudaOutT>( m_plan, in, out, m_cufftDirection );
276}
277
278template <typename inputT, typename outputT, size_t rank>
279cufftResult fftT<inputT, outputT, rank, 1>::operator()( cudaPtrOutT &out, cudaPtrInT &in ) const
280{
281 return mx::cuda::cufftExec<cudaInT, cudaOutT>( m_plan, in.data(), out.data(), m_cufftDirection );
282}
283
284} // namespace ft
285} // namespace math
286} // namespace mx
287
288#endif // fftTcuda_hpp
fftT(int nx, dir ndir=dir::forward, bool inPlace=false, typename std::enable_if< crank==1 >::type *=0)
Constructor for rank 1 FFT.
cufftResult operator()(cudaPtrOutT &out, cudaPtrInT &in) const
Conduct the FFT.
void plan(int nx, ft::dir ndir=dir::forward, bool inPlace=false, typename std::enable_if< crank==1 >::type *=0)
Planning routine for rank 1 transforms.
void destroyPlan()
Destroy (de-allocate) the plan.
fftT(int nx, int ny, int nz, dir ndir=dir::forward, bool inPlace=false, typename std::enable_if< crank==3 >::type *=0)
Constructor for rank 3 FFT.
void plan(int nx, int ny, int nz, ft::dir ndir=dir::forward, bool inPlace=false, typename std::enable_if< crank==3 >::type *=0)
Planning routine for rank 3 transforms.
fftT(int nx, int ny, dir ndir=dir::forward, bool inPlace=false, typename std::enable_if< crank==2 >::type *=0)
Constructor for rank 2 FFT.
ft::dir direction()
Get the direction of this FFT.
cufftResult operator()(cudaOutT *out, cudaInT *in) const
Conduct the FFT.
void plan(int nx, int ny, ft::dir ndir=dir::forward, bool inPlace=false, typename std::enable_if< crank==2 >::type *=0)
Planning routine for rank 2 transforms.
The Fast Fourier Transform interface.
dir
Directions of the Fourier Transform.
Definition ftTypes.hpp:41
@ backward
Specifies the backward transform.
@ forward
Specifies the forward transform.
constexpr floatT six_fifths()
Return 6/5 in the specified precision.
The mxlib c++ namespace.
Definition mxlib.hpp:37
std::complex< realT > complexT
The complex data type.
_realT realT
The real data type (_realT is actually defined in specializations).