LCOV - code coverage report
Current view: top level - math/ft - mftT.hpp (source / functions) Coverage Total Hit
Test: mxlib Lines: 4.5 % 44 2
Test Date: 2026-02-19 16:58:26 Functions: 33.3 % 3 1

            Line data    Source code
       1              : /** \file mftT.hpp
       2              :  * \brief The Matrix Fourier Transform interface
       3              :  * \ingroup ft_files
       4              :  * \author Jared R. Males (jaredmales@gmail.com)
       5              :  *
       6              :  */
       7              : 
       8              : //***********************************************************************//
       9              : // Copyright 2024-2025 Jared R. Males (jaredmales@gmail.com)
      10              : //
      11              : // This file is part of mxlib.
      12              : //
      13              : // mxlib is free software: you can redistribute it and/or modify
      14              : // it under the terms of the GNU General Public License as published by
      15              : // the Free Software Foundation, either version 3 of the License, or
      16              : // (at your option) any later version.
      17              : //
      18              : // mxlib is distributed in the hope that it will be useful,
      19              : // but WITHOUT ANY WARRANTY; without even the implied warranty of
      20              : // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      21              : // GNU General Public License for more details.
      22              : //
      23              : // You should have received a copy of the GNU General Public License
      24              : // along with mxlib.  If not, see <http://www.gnu.org/licenses/>.
      25              : //***********************************************************************//
      26              : 
      27              : #ifndef mftT_hpp
      28              : #define mftT_hpp
      29              : 
      30              : #pragma GCC system_header
#include <cstddef>
#include <stdexcept>
#include <type_traits>

#include <Eigen/Dense>

#include "ftTypes.hpp"
#include "../../math/constants.hpp"
      35              : 
      36              : /** \addtogroup mft
      37              :  *
      38              :  * The Matrix Fourier Transform is the <a href="https://en.wikipedia.org/wiki/DFT_matrix">
      39              :  * matrix-multiplication implementation of the
      40              :  * Discrete Fourier Transform</a>.  It is slower than the FFT given the same size matrix, e.g.
      41              :  * <a href="https://pages.hmc.edu/ruye/e161/lectures/fourier/node11.html">in 2D</a>
      42              :  * \f$O(N^3)\f$ vs \f$O(N^2\ln(N))\f$.  However, compared to zero-padding,
      43              :  * it can provide advantages in both speed
      44              :  * and memory required for problems where oversampling is needed but only over a small region
      45              :  * of the output.
      46              :  */
      47              : namespace mx
      48              : {
      49              : namespace math
      50              : {
      51              : namespace ft
      52              : {
      53              : 
      54              : template <typename _inputT, typename _outputT, size_t _rank, int _cudaGPU = 0>
      55              : class mftT;
      56              : 
      57              : /// Matrix Fourier Transforms
      58              : /** Calculates the Discrete Fourier Transform (DFT) using matrix multiplication.  This
      59              :  *  is normally much less efficient than the FFT, but for large oversampling (padding)
      60              :  *  the Matrix FT (MFT) will be much more space efficient and faster at the cost of
      61              :  *  "field of view".
      62              :  *
      63              :  *  This interface is modeled after the \ref fftT<_inputT, _outputT, _rank, 0> "fftT" interface to fftw, and is interoperable with it.
      64              :  *  That means that using, e.g., mftT (unshifted, not-oversampled) for the forward transform
      65              :  *  and \ref fftT<_inputT, _outputT, _rank, 0> "fftT" for the backward
      66              :  *  transform is equivalent to using one or the other for both. Note that
      67              :  *  the MFT is normalized as in fftw, so by 1/N on the forward and by 1 on the backward.
      68              :  *
      69              :  *  Oversampling is optionally included in the transform, which is equivalent to zero padding
      70              :  *  in terms of increased resolution \cite soummer_2007.
      71              :  *  The cost is that the output domain is truncated by the oversampling factor.  I.e. if we
      72              :  *  oversample by a factor of 10, only 1/10th of the output transform is available.
      73              :  *
      74              :  *  The output can be shifted as part of the MFT calculation, which is implemented similarly
      75              :  *  to <a href="https://www.mathworks.com/matlabcentral/fileexchange/18401-efficient-subpixel-image-registration-by-cross-correlation">
      76              :  *  this matlab code</a>.
      77              :  *
      78              :  *  Note that when either oversampling or shifting is done on a forward (backward) transform,
      79              :  *  a subsequent backward (forward) transform will not in general be the inverse.
      80              :  *
      81              :  *  \tparam inputT is the input type of the transform, only complex types are suppored by mftT
      82              :  *  \tparam outputT is the output type of the transform, only complex types are suppored by mftT
      83              :  *  \tparam _rank is the rank of the transform. Currently only rank 2 is implemented.
      84              :  *
      85              :  *  \ingroup mft
      86              :  */
      87              : template <typename _inputT, typename _outputT, size_t _rank>
      88              : class mftT<_inputT, _outputT, _rank, 0>
      89              : {
      90              :     typedef _inputT inputT;
      91              :     typedef _outputT outputT;
      92              : 
      93              :     typedef typename fftwTypeSpec<inputT, outputT>::realT realT;
      94              : 
      95              :     typedef typename fftwTypeSpec<_inputT, _outputT>::complexT complexT;
      96              : 
      97              :     static const size_t rank = _rank;
      98              : 
      99              :     typedef Eigen::Array<inputT, -1, -1> eigenArrayInT;
     100              :     typedef Eigen::Array<outputT, -1, -1> eigenArrayOutT;
     101              : 
     102              :     typedef Eigen::Matrix<complexT, -1, -1> eigenMatrixT;
     103              : 
     104              :   protected:
     105              :     dir m_dir{ dir::forward }; /**< Direction of this MFT, either dir::forward (default)
     106              :                                      or dir::backward */
     107              : 
     108              :     int m_szX{ 0 }; ///< Size of the x dimension
     109              :     int m_szY{ 0 }; ///< Size of the y dimension
     110              :     int m_szZ{ 0 }; ///< size of the z dimension
     111              : 
     112              :     float m_osFac{ 1 }; ///< The oversampling factor
     113              : 
     114              :     realT m_xOff{ 0 }; ///< The offset in the rows direction for the center of the DFT.
     115              :     realT m_yOff{ 0 }; ///< The offset in the columns direction for the center of the DFT.
     116              :   public:
     117              :     eigenMatrixT m_dftR; ///< DFT matrix for the rows
     118              :     eigenMatrixT m_dftC; ///< DFT matrix for the columnss
     119              : 
     120              :   public:
     121              :     /// Default c'tor
     122              :     mftT();
     123              : 
     124              :     /// Constructor for rank 1 MFT.
     125              :     template <size_t crank = _rank>
     126              :     mftT( int nx,                   ///< [in] the desired size of the MFT
     127              :           dir ndir = dir::forward, /**< [in] [optional] direction of this MFT, either dir::forward
     128              :                                                          (default) or dir::backward */
     129              :           realT xOff = 0,           /**< [in] [optional] the x offset of the center of the
     130              :                                                          transformed array. Default 0.*/
     131              :           realT osFac = 1.0,        /**< [in] [optional] the oversampling factor.  Default 1. */
     132              :           typename std::enable_if<crank == 1>::type * = 0 ) = delete;
     133              : 
     134              :     /// Constructor for rank 2 MFT.
     135              :     template <size_t crank = _rank>
     136              :     mftT( int nx,                   ///< [in] the desired x size of the MFT
     137              :           int ny,                   ///< [in] the desired y size of the MFT
     138              :           dir ndir = dir::forward, /**< [in] [optional] direction of this MFT, either dir::forward
     139              :                                                          (default) or dir::backward */
     140              :           realT xOff = 0,           /**< [in] [optional] the x offset of the center of the
     141              :                                                          transformed array. Default 0.*/
     142              :           realT yOff = 0,           /**< [in] [optional] the y offset of the center of the
     143              :                                                          transformed array  Default 0.*/
     144              :           realT osFac = 1.0,        /**< [in] [optional] the oversampling factor.  Default 1. */
     145              :           typename std::enable_if<crank == 2>::type * = 0 );
     146              : 
     147              :     /// Constructor for rank 3 MFT.
     148              :     template <size_t crank = _rank>
     149              :     mftT( int nx,                   ///< [in] the desired x size of the MFT
     150              :           int ny,                   ///< [in] the desired y size of the MFT
     151              :           int nz,                   ///< [in] the desired z size of the MFT
     152              :           dir ndir = dir::forward, /**< [in] [optional] direction of this MFT, either dir::forward
     153              :                                                          (default) or dir::backward */
     154              :           realT xOff = 0,           /**< [in] [optional] the x offset of the center of the
     155              :                                                          transformed array. Default 0.*/
     156              :           realT yOff = 0,           /**< [in] [optional] the y offset of the center of the
     157              :                                                          transformed array  Default 0.*/
     158              :           realT zOff = 0,           /**< [in] [optional] the z offset of the center of the
     159              :                                                          transformed array  Default 0.*/
     160              :           realT osFac = 1.0,        /**< [in] [optional] the oversampling factor.  Default 1. */
     161              :           typename std::enable_if<crank == 3>::type * = 0 ) = delete;
     162              : 
     163              :     /// Planning routine for rank 2 transforms.
     164              :     template <size_t crank = _rank>
     165              :     void plan( int nx,                   ///< [in] the desired x size of the MFT
     166              :                int ny,                   ///< [in] the desired y size of the MFT
     167              :                dir ndir = dir::forward, /**< [in] [optional] direction of this MFT, either dir::forward
     168              :                                                               (default) or dir::backward */
     169              :                realT xOff = 0,           /**< [in] [optional] the x offset of the center of the
     170              :                                                               transformed array. Default 0.*/
     171              :                realT yOff = 0,           /**< [in] [optional] the y offset of the center of the
     172              :                                                               transformed array  Default 0.*/
     173              :                realT osFac = 1.0,        /**< [in] [optional] the oversampling factor.  Default 1. */
     174              :                typename std::enable_if<crank == 2>::type * = 0 );
     175              : 
     176              :     /// Conduct the MFT
     177              :     void operator()( eigenArrayOutT &out, /**< [out] the output of the DFT */
     178              :                      const eigenArrayInT &in /**< [in] the input to the DFT */ ) const;
     179              : };
     180              : 
     181              : template <typename inputT, typename outputT, size_t rank>
     182           12 : mftT<inputT, outputT, rank, 0>::mftT()
     183              : {
     184           12 : }
     185              : 
     186              : template <typename inputT, typename outputT, size_t rank>
     187              : template <size_t crank>
     188              : mftT<inputT, outputT, rank, 0>::mftT(
     189              :     int nx, int ny, dir ndir, realT xoff, realT yoff, realT osFac, typename std::enable_if<crank == 2>::type * )
     190              : {
     191              :     plan( nx, ny, ndir, xoff, yoff, osFac );
     192              : }
     193              : 
     194              : template <typename inputT, typename outputT, size_t rank>
     195              : template <size_t crank>
     196            0 : void mftT<inputT, outputT, rank, 0>::plan(
     197              :     int nx, int ny, dir ndir, realT xOff, realT yOff, realT osFac, typename std::enable_if<crank == 2>::type * )
     198              : {
     199            0 :     if(m_szX == nx && m_szY == ny && m_dir == ndir && m_xOff == xOff && m_yOff == yOff && m_osFac == osFac)
     200              :     {
     201            0 :         return;
     202              :     }
     203              : 
     204            0 :     m_szX = nx;
     205            0 :     m_szY = ny;
     206            0 :     m_dir = ndir;
     207            0 :     m_xOff = xOff;
     208            0 :     m_yOff = yOff;
     209            0 :     m_osFac = osFac;
     210              : 
     211            0 :     if( m_szX != m_szY )
     212              :     {
     213            0 :         throw std::invalid_argument( "MFT of non-square size is not implemented. nx must equal ny." );
     214              :     }
     215              : 
     216              :     // These should depend on szX too.
     217            0 :     m_dftR.resize( m_szX, m_szX );
     218            0 :     m_dftC.resize( m_szX, m_szX );
     219              : 
     220              :     // There is probably an osNx and an osNy?
     221            0 :     realT osN = m_szX * m_osFac;
     222              : 
     223            0 :     if( m_dir == dir::forward )
     224              :     {
     225            0 :         realT sign = -1;
     226            0 :         realT norm = 1.0 / ( m_szX * m_szY );
     227              : 
     228            0 :         for( int cc = 0; cc < m_szY; ++cc )
     229              :         {
     230            0 :             realT ccx = cc;
     231              : 
     232            0 :             for( int rr = 0; rr < m_szX; ++rr )
     233              :             {
     234            0 :                 realT x = ( rr - m_xOff ) * ccx / osN;
     235              : 
     236            0 :                 m_dftR( rr, cc ) = norm * exp( complexT( { 0, sign * 2 * pi<realT>() * x } ) );
     237              : 
     238            0 :                 realT rrx = rr;
     239              : 
     240            0 :                 x = rrx * ( cc - m_yOff ) / osN;
     241              : 
     242            0 :                 m_dftC( rr, cc ) = norm * exp( complexT( { 0, sign * 2 * pi<realT>() * x } ) );
     243              :             }
     244              :         }
     245              :     }
     246              :     else
     247              :     {
     248            0 :         realT sign = +1;
     249            0 :         realT norm = 1.0;
     250              : 
     251            0 :         for( int cc = 0; cc < m_szY; ++cc )
     252              :         {
     253            0 :             realT ccx = cc;
     254            0 :             if( ccx > m_szY / 2 )
     255            0 :                 ccx = -1 * ( m_szY - ccx );
     256              : 
     257            0 :             for( int rr = 0; rr < m_szX; ++rr )
     258              :             {
     259            0 :                 realT rrx = rr;
     260            0 :                 if( rrx > m_szX / 2 )
     261            0 :                     rrx = -1 * ( m_szX - rrx );
     262              : 
     263            0 :                 realT x = rrx * ( cc - m_xOff ) / osN;
     264              : 
     265            0 :                 m_dftR( cc, rr ) = norm * exp( complexT( { 0, sign * 2 * pi<realT>() * x } ) );
     266              : 
     267            0 :                 x = ( rr - m_yOff ) * ccx / osN;
     268              : 
     269            0 :                 m_dftC( cc, rr ) = norm * exp( complexT( { 0, sign * 2 * pi<realT>() * x } ) );
     270              :             }
     271              :         }
     272              :     }
     273              : }
     274              : 
     275              : template <typename inputT, typename outputT, size_t rank>
     276            0 : void mftT<inputT, outputT, rank, 0>::operator()( eigenArrayOutT &out, const eigenArrayInT &in ) const
     277              : {
     278            0 :     out = ( m_dftR * in.matrix() * m_dftC ).array();
     279            0 : }
     280              : 
     281              : } // namespace ft
     282              : } // namespace math
     283              : } // namespace mx
     284              : 
     285              : #endif // mdft_hpp
        

Generated by: LCOV version 2.0-1