34#include "../mxError.hpp"
35#include "../math/ft/fftT.hpp"
36#include "../improc/eigenCube.hpp"
43namespace psdFilterTypes
47template <
typename realT,
size_t rank>
50template <
typename realT>
53 typedef std::vector<realT> realArrayT;
54 typedef std::vector<realT> *realArrayMapT;
55 typedef std::vector<std::complex<realT>> complexArrayT;
57 static void clear( realArrayT &arr )
62 static void clear( complexArrayT &arr )
68template <
typename realT>
69struct arrayT<realT, 2>
71 typedef Eigen::Array<realT, Eigen::Dynamic, Eigen::Dynamic> realArrayT;
72 typedef Eigen::Map<Eigen::Array<realT, Eigen::Dynamic, Eigen::Dynamic>> realArrayMapT;
74 typedef Eigen::Array<std::complex<realT>, Eigen::Dynamic, Eigen::Dynamic> complexArrayT;
76 static void clear( realArrayT &arr )
81 static void clear( complexArrayT &arr )
87template <
typename realT>
88struct arrayT<realT, 3>
90 typedef improc::eigenCube<realT> realArrayT;
91 typedef improc::eigenCube<realT> *realArrayMapT;
92 typedef improc::eigenCube<std::complex<realT>> complexArrayT;
94 static void clear( realArrayT &arr )
96 arr.resize( 0, 0, 0 );
99 static void clear( complexArrayT &arr )
101 arr.resize( 0, 0, 0 );
108template <
typename _realT,
size_t rank,
int cuda = 0>
143template <
typename _realT,
size_t _rank>
144class psdFilter<_realT, _rank, 0>
150 static const size_t rank = _rank;
169 bool m_owner{
false };
172 mutable complexArrayT
232 template <
size_t crank = rank>
233 int setSize(
typename std::enable_if<crank == 1>::type * = 0 );
245 template <
size_t crank = rank>
246 int setSize(
typename std::enable_if<crank == 2>::type * = 0 );
258 template <
size_t crank = rank>
259 int setSize(
typename std::enable_if<crank == 3>::type * = 0 );
304 template <
size_t crank = rank>
307 typename std::enable_if<crank == 1>::type * = 0 );
324 template <
size_t crank = rank>
328 typename std::enable_if<crank == 2>::type * = 0 );
345 template <
size_t crank = rank>
350 typename std::enable_if<crank == 3>::type * = 0 );
365 template <
size_t crank = rank>
368 typename std::enable_if<crank == 1>::type * = 0 );
383 template <
size_t crank = rank>
387 typename std::enable_if<crank == 2>::type * = 0 );
402 template <
size_t crank = rank>
407 typename std::enable_if<crank == 3>::type * = 0 );
422 template <
size_t crank = rank>
425 typename std::enable_if<crank == 1>::type * = 0 );
440 template <
size_t crank = rank>
444 typename std::enable_if<crank == 2>::type * = 0 );
459 template <
size_t crank = rank>
464 typename std::enable_if<crank == 3>::type * = 0 );
483 template <
size_t crank = rank>
487 typename std::enable_if<crank == 1>::type * = 0 )
const;
498 template <
size_t crank = rank>
502 typename std::enable_if<crank == 2>::type * = 0 )
const;
512 template <
size_t crank = rank>
516 typename std::enable_if<crank == 2>::type * = 0 )
const;
527 template <
size_t crank = rank>
531 typename std::enable_if<crank == 3>::type * = 0 )
const;
570template <
typename realT,
size_t rank>
571psdFilter<realT, rank>::psdFilter()
575template <
typename realT,
size_t rank>
576psdFilter<realT, rank>::~psdFilter()
578 if( m_psdSqrt && m_owner )
584template <
typename realT,
size_t rank>
585int psdFilter<realT, rank>::psdSqrt( realArrayT *npsdSqrt )
587 if( m_psdSqrt && m_owner )
592 m_psdSqrt = npsdSqrt;
600template <
typename realT,
size_t rank>
601int psdFilter<realT, rank>::psdSqrt(
const realArrayT &npsdSqrt )
603 if( m_psdSqrt && m_owner )
608 m_psdSqrt =
new realArrayT;
610 ( *m_psdSqrt ) = npsdSqrt;
618template <
typename realT,
size_t rank>
619template <
size_t crank>
620int psdFilter<realT, rank>::setSize(
typename std::enable_if<crank == 1>::type * )
624 mxError(
"psdFilter", MXE_PARAMNOTSET,
"m_psdSqrt has not been set yet, is still NULL." );
628 if( m_rows == m_psdSqrt->size() )
633 m_rows = m_psdSqrt->size();
637 m_ftWork.resize( m_rows );
646template <
typename realT,
size_t rank>
647template <
size_t crank>
648int psdFilter<realT, rank>::setSize(
typename std::enable_if<crank == 2>::type * )
652 mxError(
"psdFilter", MXE_PARAMNOTSET,
"m_psdSqrt has not been set yet, is still NULL." );
656 if( m_rows == m_psdSqrt->rows() && m_cols == m_psdSqrt->cols() )
661 m_rows = m_psdSqrt->rows();
662 m_cols = m_psdSqrt->cols();
665 m_ftWork.resize( m_rows, m_cols );
674template <
typename realT,
size_t rank>
675template <
size_t crank>
676int psdFilter<realT, rank>::setSize(
typename std::enable_if<crank == 3>::type * )
680 mxError(
"psdFilter", MXE_PARAMNOTSET,
"m_psdSqrt has not been set yet, is still NULL." );
684 if( m_rows == m_psdSqrt->rows() && m_cols == m_psdSqrt->cols() && m_planes == m_psdSqrt->planes() )
689 m_rows = m_psdSqrt->rows();
690 m_cols = m_psdSqrt->cols();
691 m_planes = m_psdSqrt->planes();
693 m_ftWork.resize( m_rows, m_cols, m_planes );
702template <
typename realT,
size_t rank>
703int psdFilter<realT, rank>::rows()
708template <
typename realT,
size_t rank>
709int psdFilter<realT, rank>::cols()
714template <
typename realT,
size_t rank>
715int psdFilter<realT, rank>::planes()
720template <
typename realT,
size_t rank>
721template <
size_t crank>
722int psdFilter<realT, rank>::psdSqrt( realArrayT *npsdSqrt, realT df,
typename std::enable_if<crank == 1>::type * )
725 return psdSqrt( npsdSqrt );
728template <
typename realT,
size_t rank>
729template <
size_t crank>
730int psdFilter<realT, rank>::psdSqrt( realArrayT *npsdSqrt,
733 typename std::enable_if<crank == 2>::type * )
737 return psdSqrt( npsdSqrt );
740template <
typename realT,
size_t rank>
741template <
size_t crank>
742int psdFilter<realT, rank>::psdSqrt(
743 realArrayT *npsdSqrt, realT dk1, realT dk2, realT df,
typename std::enable_if<crank == 3>::type * )
748 return psdSqrt( npsdSqrt );
751template <
typename realT,
size_t rank>
752template <
size_t crank>
753int psdFilter<realT, rank>::psdSqrt(
const realArrayT &npsdSqrt, realT df,
typename std::enable_if<crank == 1>::type * )
756 return psdSqrt( npsdSqrt );
759template <
typename realT,
size_t rank>
760template <
size_t crank>
761int psdFilter<realT, rank>::psdSqrt(
const realArrayT &npsdSqrt,
764 typename std::enable_if<crank == 2>::type * )
768 return psdSqrt( npsdSqrt );
771template <
typename realT,
size_t rank>
772template <
size_t crank>
773int psdFilter<realT, rank>::psdSqrt(
774 const realArrayT &npsdSqrt, realT dk1, realT dk2, realT df,
typename std::enable_if<crank == 3>::type * )
779 return psdSqrt( npsdSqrt );
782template <
typename realT,
size_t rank>
783template <
size_t crank>
784int psdFilter<realT, rank>::psd(
const realArrayT &npsd,
const realT df1,
typename std::enable_if<crank == 1>::type * )
786 if( m_psdSqrt && m_owner )
791 m_psdSqrt =
new realArrayT;
794 m_psdSqrt->resize( npsd.size() );
795 for(
size_t n = 0; n < npsd.size(); ++n )
796 ( *m_psdSqrt )[n] = sqrt( npsd[n] );
807template <
typename realT,
size_t rank>
808template <
size_t crank>
809int psdFilter<realT, rank>::psd(
const realArrayT &npsd,
812 typename std::enable_if<crank == 2>::type * )
814 if( m_psdSqrt && m_owner )
819 m_psdSqrt =
new realArrayT;
821 ( *m_psdSqrt ) = npsd.sqrt();
832template <
typename realT,
size_t rank>
833template <
size_t crank>
834int psdFilter<realT, rank>::psd(
const realArrayT &npsd,
838 typename std::enable_if<crank == 3>::type * )
840 if( m_psdSqrt && m_owner )
845 m_psdSqrt =
new realArrayT;
848 m_psdSqrt->resize( npsd.rows(), npsd.cols(), npsd.planes() );
849 for(
int pp = 0; pp < npsd.planes(); ++pp )
851 for(
int cc = 0; cc < npsd.cols(); ++cc )
853 for(
int rr = 0; rr < npsd.rows(); ++rr )
855 m_psdSqrt->image( pp )( rr, cc ) = sqrt( npsd.image( pp )( rr, cc ) );
871template <
typename realT,
size_t rank>
872void psdFilter<realT, rank>::clear()
875 psdFilterTypes::arrayT<realT, rank>::clear( m_ftWork );
881 if( m_psdSqrt && m_owner )
888template <
typename realT,
size_t rank>
889template <
size_t crank>
890int psdFilter<realT, rank>::filter( realArrayT &noise,
892 typename std::enable_if<crank == 1>::type * )
const
894 for(
int nn = 0; nn < noise.size(); ++nn )
895 m_ftWork[nn] = complexT( noise[nn], 0 );
898 m_fft_fwd( m_ftWork.data(), m_ftWork.data() );
901 for(
int nn = 0; nn < m_ftWork.size(); ++nn )
902 m_ftWork[nn] *= ( *m_psdSqrt )[nn];
904 m_fft_bwd( m_ftWork.data(), m_ftWork.data() );
907 realT norm = sqrt( noise.size() / m_dFreq1 );
908 for(
int nn = 0; nn < m_ftWork.size(); ++nn )
909 noise[nn] = m_ftWork[nn].real() / norm;
911 if( noiseIm !=
nullptr )
913 for(
int nn = 0; nn < m_ftWork.size(); ++nn )
914 ( *noiseIm )[nn] = m_ftWork[nn].imag() / norm;
920template <
typename realT,
size_t rank>
921template <
size_t crank>
922int psdFilter<realT, rank>::filter( realArrayT &noise,
924 typename std::enable_if<crank == 2>::type * )
const
927 for(
int ii = 0; ii < noise.rows(); ++ii )
929 for(
int jj = 0; jj < noise.cols(); ++jj )
931 m_ftWork( ii, jj ) = complexT( noise( ii, jj ), 0 );
936 m_fft_fwd( m_ftWork.data(), m_ftWork.data() );
939 m_ftWork *= *m_psdSqrt;
941 m_fft_bwd( m_ftWork.data(), m_ftWork.data() );
943 realT norm = sqrt( noise.rows() * noise.cols() / ( m_dFreq1 * m_dFreq2 ) );
946 noise = m_ftWork.real() / norm;
948 if( noiseIm !=
nullptr )
950 *noiseIm = m_ftWork.imag() / norm;
956template <
typename realT,
size_t rank>
957template <
size_t crank>
958int psdFilter<realT, rank>::filter( realArrayMapT noise,
960 typename std::enable_if<crank == 2>::type * )
const
963 for(
int ii = 0; ii < noise.rows(); ++ii )
965 for(
int jj = 0; jj < noise.cols(); ++jj )
967 m_ftWork( ii, jj ) = complexT( noise( ii, jj ), 0 );
972 m_fft_fwd( m_ftWork.data(), m_ftWork.data() );
975 m_ftWork *= *m_psdSqrt;
977 m_fft_bwd( m_ftWork.data(), m_ftWork.data() );
979 realT norm = sqrt( noise.rows() * noise.cols() / ( m_dFreq1 * m_dFreq2 ) );
982 noise = m_ftWork.real() / norm;
984 if( noiseIm !=
nullptr )
986 *noiseIm = m_ftWork.imag() / norm;
992template <
typename realT,
size_t rank>
993template <
size_t crank>
994int psdFilter<realT, rank>::filter( realArrayT &noise,
996 typename std::enable_if<crank == 3>::type * )
const
999 for(
int pp = 0; pp < noise.planes(); ++pp )
1001 for(
int ii = 0; ii < noise.rows(); ++ii )
1003 for(
int jj = 0; jj < noise.cols(); ++jj )
1005 m_ftWork.image( pp )( ii, jj ) = complexT( noise.image( pp )( ii, jj ), 0 );
1011 m_fft_fwd( m_ftWork.data(), m_ftWork.data() );
1014 for(
int pp = 0; pp < noise.planes(); ++pp )
1015 m_ftWork.image( pp ) *= m_psdSqrt->image( pp );
1017 m_fft_bwd( m_ftWork.data(), m_ftWork.data() );
1021 realT norm = sqrt( m_rows * m_cols * m_planes / ( m_dFreq1 * m_dFreq2 * m_dFreq3 ) );
1022 for(
int pp = 0; pp < noise.planes(); ++pp )
1023 noise.image( pp ) = m_ftWork.image( pp ).real() / norm;
1025 if( noiseIm !=
nullptr )
1027 for(
int pp = 0; pp < noise.planes(); ++pp )
1028 noiseIm->image( pp ) = m_ftWork.image( pp ).imag() / norm;
1034template <
typename realT,
size_t rank>
1035int psdFilter<realT, rank>::operator()( realArrayT &noise )
const
1037 return filter( noise );
1040template <
typename realT,
size_t rank>
1041int psdFilter<realT, rank>::operator()( realArrayMapT noise )
const
1043 return filter( noise );
1046template <
typename realT,
size_t rank>
1047int psdFilter<realT, rank>::operator()( realArrayT &noise, realArrayT &noiseIm )
const
1049 return filter( noise, &noiseIm );
int setSize(typename std::enable_if< crank==1 >::type *=0)
Set the size of the filter.
int operator()(realArrayT &noise, realArrayT &noiseIm) const
Apply the filter.
complexArrayT m_ftWork
Working memory for the FFT. Declared mutable so it can be accessed in the const filter method.
int rows()
Get the number of rows in the filter.
int psd(const realArrayT &npsd, const realT df, typename std::enable_if< crank==1 >::type *=0)
Set the square root of the PSD, computed from the supplied PSD.
int filter(realArrayMapT noise, realArrayT *noiseIm=nullptr, typename std::enable_if< crank==2 >::type *=0) const
Apply the filter.
psdFilterTypes::arrayT< realT, rank >::complexArrayT complexArrayT
std::vector for rank==1, Eigen::Array for rank==2, eigenCube for rank==3.
int psdSqrt(realArrayT *npsdSqrt, realT df, typename std::enable_if< crank==1 >::type *=0)
math::ft::fftT< complexT, complexT, rank, 0 > m_fft_fwd
FFT object for the forward transform.
int filter(realArrayT &noise, realArrayT *noiseIm=nullptr, typename std::enable_if< crank==3 >::type *=0) const
Apply the filter.
psdFilterTypes::arrayT< realT, rank >::realArrayMapT realArrayMapT
std::vector for rank==1, Eigen::Map for rank==2, eigenCube for rank==3.
int filter(realArrayT &noise, realArrayT *noiseIm=nullptr, typename std::enable_if< crank==1 >::type *=0) const
Apply the filter.
int psdSqrt(realArrayT *npsdSqrt, realT dk1, realT dk2, typename std::enable_if< crank==2 >::type *=0)
int setSize(typename std::enable_if< crank==2 >::type *=0)
Set the size of the filter.
int operator()(realArrayT &noise) const
Apply the filter.
int psd(const realArrayT &npsd, const realT dk1, const realT dk2, typename std::enable_if< crank==2 >::type *=0)
Set the square root of the PSD, computed from the supplied PSD.
int psdSqrt(realArrayT *npsdSqrt, realT dk1, realT dk2, realT df, typename std::enable_if< crank==3 >::type *=0)
int cols()
Get the number of columns in the filter.
std::complex< _realT > complexT
Complex floating point type.
psdFilterTypes::arrayT< realT, rank >::realArrayT realArrayT
std::vector for rank==1, Eigen::Array for rank==2, eigenCube for rank==3.
int psdSqrt(const realArrayT &npsdSqrt, realT dk1, realT dk2, realT df, typename std::enable_if< crank==3 >::type *=0)
Set the square root of the PSD.
math::ft::fftT< complexT, complexT, rank, 0 > m_fft_bwd
FFT object for the backward transform.
int setSize(typename std::enable_if< crank==3 >::type *=0)
Set the size of the filter.
void clear()
De-allocate all working memory and reset to initial state.
int filter(realArrayT &noise, realArrayT *noiseIm=nullptr, typename std::enable_if< crank==2 >::type *=0) const
Apply the filter.
_realT realT
Real floating point type.
int operator()(realArrayMapT noise) const
Apply the filter.
int psdSqrt(const realArrayT &npsdSqrt, realT df, typename std::enable_if< crank==1 >::type *=0)
Set the square root of the PSD.
int psd(const realArrayT &npsd, const realT dk1, const realT dk2, const realT df, typename std::enable_if< crank==3 >::type *=0)
Set the square root of the PSD, computed from the supplied PSD.
int psdSqrt(const realArrayT &npsdSqrt)
Set the square root of the PSD.
int psdSqrt(const realArrayT &npsdSqrt, realT dk1, realT dk2, typename std::enable_if< crank==2 >::type *=0)
Set the square root of the PSD.
int planes()
Get the number of planes in the filter.
int psdSqrt(realArrayT *npsdSqrt)
@ backward
Specifies the backward transform.
@ forward
Specifies the forward transform.
Types for different ranks in psdFilter.