Line data Source code
1 : /** \file mftT.hpp
2 : * \brief The Matrix Fourier Transform interface
3 : * \ingroup ft_files
4 : * \author Jared R. Males (jaredmales@gmail.com)
5 : *
6 : */
7 :
8 : //***********************************************************************//
9 : // Copyright 2024-2025 Jared R. Males (jaredmales@gmail.com)
10 : //
11 : // This file is part of mxlib.
12 : //
13 : // mxlib is free software: you can redistribute it and/or modify
14 : // it under the terms of the GNU General Public License as published by
15 : // the Free Software Foundation, either version 3 of the License, or
16 : // (at your option) any later version.
17 : //
18 : // mxlib is distributed in the hope that it will be useful,
19 : // but WITHOUT ANY WARRANTY; without even the implied warranty of
20 : // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 : // GNU General Public License for more details.
22 : //
23 : // You should have received a copy of the GNU General Public License
24 : // along with mxlib. If not, see <http://www.gnu.org/licenses/>.
25 : //***********************************************************************//
26 :
27 : #ifndef mftT_hpp
28 : #define mftT_hpp
29 :
30 : #pragma GCC system_header
#include <cstddef>
#include <stdexcept>
#include <type_traits>

#include <Eigen/Dense>

#include "ftTypes.hpp"
#include "../../math/constants.hpp"
35 :
36 : /** \addtogroup mft
37 : *
38 : * The Matrix Fourier Transform is the <a href="https://en.wikipedia.org/wiki/DFT_matrix">
39 : * matrix-multiplication implementation of the
40 : * Discrete Fourier Transform</a>. It is slower than the FFT given the same size matrix, e.g.
41 : * <a href="https://pages.hmc.edu/ruye/e161/lectures/fourier/node11.html">in 2D</a>
42 : * \f$O(N^3)\f$ vs \f$O(N^2\ln(N))\f$. However, compared to zero-padding,
43 : * it can provide advantages in both speed
44 : * and memory required for problems where oversampling is needed but only over a small region
45 : * of the output.
46 : */
47 : namespace mx
48 : {
49 : namespace math
50 : {
51 : namespace ft
52 : {
53 :
/// Primary declaration of the Matrix Fourier Transform class template.
/** Only declared here. The last parameter defaults to 0, so `mftT<inT, outT, rank>`
  * names the same type as `mftT<inT, outT, rank, 0>`.
  */
template <class _inputT, class _outputT, size_t _rank, int _cudaGPU = 0>
class mftT;
56 :
57 : /// Matrix Fourier Transforms
58 : /** Calculates the Discrete Fourier Transform (DFT) using matrix multiplication. This
59 : * is normally much less efficient than the FFT, but for large oversampling (padding)
60 : * the Matrix FT (MFT) will be much more space efficient and faster at the cost of
61 : * "field of view".
62 : *
63 : * This interface is modeled after the \ref fftT<_inputT, _outputT, _rank, 0> "fftT" interface to fftw, and is interoperable with it.
64 : * That means that using, e.g., mftT (unshifted, not-oversampled) for the forward transform
65 : * and \ref fftT<_inputT, _outputT, _rank, 0> "fftT" for the backward
66 : * transform is equivalent to using one or the other for both. Note that
67 : * the MFT is normalized as in fftw, so by 1/N on the forward and by 1 on the backward.
68 : *
69 : * Oversampling is optionally included in the transform, which is equivalent to zero padding
70 : * in terms of increased resolution \cite soummer_2007.
71 : * The cost is that the output domain is truncated by the oversampling factor. I.e. if we
72 : * oversample by a factor of 10, only 1/10th of the output transform is available.
73 : *
74 : * The output can be shifted as part of the MFT calculation, which is implemented similarly
75 : * to <a href="https://www.mathworks.com/matlabcentral/fileexchange/18401-efficient-subpixel-image-registration-by-cross-correlation">
76 : * this matlab code</a>.
77 : *
78 : * Note that when either oversampling or shifting is done on a forward (backward) transform,
79 : * a subsequent backward (forward) transform will not in general be the inverse.
80 : *
81 : * \tparam _inputT is the input type of the transform, only complex types are supported by mftT
82 : * \tparam _outputT is the output type of the transform, only complex types are supported by mftT
83 : * \tparam _rank is the rank of the transform. Currently only rank 2 is implemented.
84 : *
85 : * \ingroup mft
86 : */
87 : template <typename _inputT, typename _outputT, size_t _rank>
88 : class mftT<_inputT, _outputT, _rank, 0>
89 : {
90 : typedef _inputT inputT;
91 : typedef _outputT outputT;
92 :
93 : typedef typename fftwTypeSpec<inputT, outputT>::realT realT;
94 :
95 : typedef typename fftwTypeSpec<_inputT, _outputT>::complexT complexT;
96 :
97 : static const size_t rank = _rank;
98 :
99 : typedef Eigen::Array<inputT, -1, -1> eigenArrayInT;
100 : typedef Eigen::Array<outputT, -1, -1> eigenArrayOutT;
101 :
102 : typedef Eigen::Matrix<complexT, -1, -1> eigenMatrixT;
103 :
104 : protected:
105 : dir m_dir{ dir::forward }; /**< Direction of this MFT, either dir::forward (default)
106 : or dir::backward */
107 :
108 : int m_szX{ 0 }; ///< Size of the x dimension
109 : int m_szY{ 0 }; ///< Size of the y dimension
110 : int m_szZ{ 0 }; ///< size of the z dimension
111 :
112 : float m_osFac{ 1 }; ///< The oversampling factor
113 :
114 : realT m_xOff{ 0 }; ///< The offset in the rows direction for the center of the DFT.
115 : realT m_yOff{ 0 }; ///< The offset in the columns direction for the center of the DFT.
116 : public:
117 : eigenMatrixT m_dftR; ///< DFT matrix for the rows
118 : eigenMatrixT m_dftC; ///< DFT matrix for the columnss
119 :
120 : public:
121 : /// Default c'tor
122 : mftT();
123 :
124 : /// Constructor for rank 1 MFT.
125 : template <size_t crank = _rank>
126 : mftT( int nx, ///< [in] the desired size of the MFT
127 : dir ndir = dir::forward, /**< [in] [optional] direction of this MFT, either dir::forward
128 : (default) or dir::backward */
129 : realT xOff = 0, /**< [in] [optional] the x offset of the center of the
130 : transformed array. Default 0.*/
131 : realT osFac = 1.0, /**< [in] [optional] the oversampling factor. Default 1. */
132 : typename std::enable_if<crank == 1>::type * = 0 ) = delete;
133 :
134 : /// Constructor for rank 2 MFT.
135 : template <size_t crank = _rank>
136 : mftT( int nx, ///< [in] the desired x size of the MFT
137 : int ny, ///< [in] the desired y size of the MFT
138 : dir ndir = dir::forward, /**< [in] [optional] direction of this MFT, either dir::forward
139 : (default) or dir::backward */
140 : realT xOff = 0, /**< [in] [optional] the x offset of the center of the
141 : transformed array. Default 0.*/
142 : realT yOff = 0, /**< [in] [optional] the y offset of the center of the
143 : transformed array Default 0.*/
144 : realT osFac = 1.0, /**< [in] [optional] the oversampling factor. Default 1. */
145 : typename std::enable_if<crank == 2>::type * = 0 );
146 :
147 : /// Constructor for rank 3 MFT.
148 : template <size_t crank = _rank>
149 : mftT( int nx, ///< [in] the desired x size of the MFT
150 : int ny, ///< [in] the desired y size of the MFT
151 : int nz, ///< [in] the desired z size of the MFT
152 : dir ndir = dir::forward, /**< [in] [optional] direction of this MFT, either dir::forward
153 : (default) or dir::backward */
154 : realT xOff = 0, /**< [in] [optional] the x offset of the center of the
155 : transformed array. Default 0.*/
156 : realT yOff = 0, /**< [in] [optional] the y offset of the center of the
157 : transformed array Default 0.*/
158 : realT zOff = 0, /**< [in] [optional] the z offset of the center of the
159 : transformed array Default 0.*/
160 : realT osFac = 1.0, /**< [in] [optional] the oversampling factor. Default 1. */
161 : typename std::enable_if<crank == 3>::type * = 0 ) = delete;
162 :
163 : /// Planning routine for rank 2 transforms.
164 : template <size_t crank = _rank>
165 : void plan( int nx, ///< [in] the desired x size of the MFT
166 : int ny, ///< [in] the desired y size of the MFT
167 : dir ndir = dir::forward, /**< [in] [optional] direction of this MFT, either dir::forward
168 : (default) or dir::backward */
169 : realT xOff = 0, /**< [in] [optional] the x offset of the center of the
170 : transformed array. Default 0.*/
171 : realT yOff = 0, /**< [in] [optional] the y offset of the center of the
172 : transformed array Default 0.*/
173 : realT osFac = 1.0, /**< [in] [optional] the oversampling factor. Default 1. */
174 : typename std::enable_if<crank == 2>::type * = 0 );
175 :
176 : /// Conduct the MFT
177 : void operator()( eigenArrayOutT &out, /**< [out] the output of the DFT */
178 : const eigenArrayInT &in /**< [in] the input to the DFT */ ) const;
179 : };
180 :
181 : template <typename inputT, typename outputT, size_t rank>
182 12 : mftT<inputT, outputT, rank, 0>::mftT()
183 : {
184 12 : }
185 :
186 : template <typename inputT, typename outputT, size_t rank>
187 : template <size_t crank>
188 : mftT<inputT, outputT, rank, 0>::mftT(
189 : int nx, int ny, dir ndir, realT xoff, realT yoff, realT osFac, typename std::enable_if<crank == 2>::type * )
190 : {
191 : plan( nx, ny, ndir, xoff, yoff, osFac );
192 : }
193 :
194 : template <typename inputT, typename outputT, size_t rank>
195 : template <size_t crank>
196 0 : void mftT<inputT, outputT, rank, 0>::plan(
197 : int nx, int ny, dir ndir, realT xOff, realT yOff, realT osFac, typename std::enable_if<crank == 2>::type * )
198 : {
199 0 : if(m_szX == nx && m_szY == ny && m_dir == ndir && m_xOff == xOff && m_yOff == yOff && m_osFac == osFac)
200 : {
201 0 : return;
202 : }
203 :
204 0 : m_szX = nx;
205 0 : m_szY = ny;
206 0 : m_dir = ndir;
207 0 : m_xOff = xOff;
208 0 : m_yOff = yOff;
209 0 : m_osFac = osFac;
210 :
211 0 : if( m_szX != m_szY )
212 : {
213 0 : throw std::invalid_argument( "MFT of non-square size is not implemented. nx must equal ny." );
214 : }
215 :
216 : // These should depend on szX too.
217 0 : m_dftR.resize( m_szX, m_szX );
218 0 : m_dftC.resize( m_szX, m_szX );
219 :
220 : // There is probably an osNx and an osNy?
221 0 : realT osN = m_szX * m_osFac;
222 :
223 0 : if( m_dir == dir::forward )
224 : {
225 0 : realT sign = -1;
226 0 : realT norm = 1.0 / ( m_szX * m_szY );
227 :
228 0 : for( int cc = 0; cc < m_szY; ++cc )
229 : {
230 0 : realT ccx = cc;
231 :
232 0 : for( int rr = 0; rr < m_szX; ++rr )
233 : {
234 0 : realT x = ( rr - m_xOff ) * ccx / osN;
235 :
236 0 : m_dftR( rr, cc ) = norm * exp( complexT( { 0, sign * 2 * pi<realT>() * x } ) );
237 :
238 0 : realT rrx = rr;
239 :
240 0 : x = rrx * ( cc - m_yOff ) / osN;
241 :
242 0 : m_dftC( rr, cc ) = norm * exp( complexT( { 0, sign * 2 * pi<realT>() * x } ) );
243 : }
244 : }
245 : }
246 : else
247 : {
248 0 : realT sign = +1;
249 0 : realT norm = 1.0;
250 :
251 0 : for( int cc = 0; cc < m_szY; ++cc )
252 : {
253 0 : realT ccx = cc;
254 0 : if( ccx > m_szY / 2 )
255 0 : ccx = -1 * ( m_szY - ccx );
256 :
257 0 : for( int rr = 0; rr < m_szX; ++rr )
258 : {
259 0 : realT rrx = rr;
260 0 : if( rrx > m_szX / 2 )
261 0 : rrx = -1 * ( m_szX - rrx );
262 :
263 0 : realT x = rrx * ( cc - m_xOff ) / osN;
264 :
265 0 : m_dftR( cc, rr ) = norm * exp( complexT( { 0, sign * 2 * pi<realT>() * x } ) );
266 :
267 0 : x = ( rr - m_yOff ) * ccx / osN;
268 :
269 0 : m_dftC( cc, rr ) = norm * exp( complexT( { 0, sign * 2 * pi<realT>() * x } ) );
270 : }
271 : }
272 : }
273 : }
274 :
275 : template <typename inputT, typename outputT, size_t rank>
276 0 : void mftT<inputT, outputT, rank, 0>::operator()( eigenArrayOutT &out, const eigenArrayInT &in ) const
277 : {
278 0 : out = ( m_dftR * in.matrix() * m_dftC ).array();
279 0 : }
280 :
281 : } // namespace ft
282 : } // namespace math
283 : } // namespace mx
284 :
285 : #endif // mftT_hpp
|