27#ifndef psdFilterCuda_hpp 
   28#define psdFilterCuda_hpp 
   30#include "../cuda/templateCudaPtr.hpp" 
   31#include "../cuda/templateCuda.hpp" 
   38#include <helper_cuda.h> 
   69template <
typename _realT>
 
   70class psdFilter<_realT, 1>
 
   75    typedef Eigen::Array<realT, Eigen::Dynamic, Eigen::Dynamic> 
realArrayT; 
 
   76    typedef Eigen::Array<complexT, Eigen::Dynamic, Eigen::Dynamic>
 
   79    typedef realT deviceRealPtrT;
 
   91    cufftHandle m_fftPlan{ 0 };
 
 
  170template <
typename realT>
 
  171psdFilter<realT, 1>::psdFilter()
 
  175template <
typename realT>
 
  176psdFilter<realT, 1>::~psdFilter()
 
  180        checkCudaErrors( cufftDestroy( m_fftPlan ) );
 
  184template <
typename realT>
 
  185int psdFilter<realT, 1>::setSize()
 
  189    std::complex<realT> scale( 1. / ( m_rows * m_cols ), 0 );
 
  190    m_scale.upload( &scale, 1 );
 
  192    checkCudaErrors( cufftPlan2d( &m_fftPlan, m_rows, m_cols, CUFFT_C2C ) );
 
  197template <
typename realT>
 
  198int psdFilter<realT, 1>::rows()
 
  203template <
typename realT>
 
  204int psdFilter<realT, 1>::cols()
 
  225template <
typename realT>
 
  226int psdFilter<realT, 1>::psdSqrt( 
const realArrayT &npsdSqrt )
 
  228    m_rows = npsdSqrt.rows();
 
  229    m_cols = npsdSqrt.cols();
 
  231    complexArrayT tmp( m_rows, m_cols );
 
  233    tmp.real() = npsdSqrt;
 
  235    m_psdSqrt.upload( tmp.data(), m_rows * m_cols );
 
  242template <
typename realT>
 
  243int psdFilter<realT, 1>::psdSqrt( 
const cuda::cudaPtr<std::complex<realT>> &npsdSqrt, 
size_t rows, 
size_t cols )
 
  248    m_psdSqrt.resize( m_rows * m_cols );
 
  251    cudaMemcpy( m_psdSqrt.m_devicePtr,
 
  252                npsdSqrt.m_devicePtr,
 
  253                m_rows * m_cols * 
sizeof( std::complex<realT> ),
 
  254                cudaMemcpyDeviceToDevice );
 
  261template <
typename realT>
 
  262void psdFilter<realT, 1>::clear()
 
  268    m_psdSqrt.resize( 0 );
 
  271template <
typename realT>
 
  272int psdFilter<realT, 1>::filter( deviceComplexPtrT *noise )
 
  277    cufftExecC2C( m_fftPlan, (cuComplex *)noise, (cuComplex *)noise, CUFFT_FORWARD );
 
  280    mx::cuda::pointwiseMul<cuComplex>
 
  281        <<<32, 256>>>( (cuComplex *)noise, (cufftComplex *)m_psdSqrt.m_devicePtr, m_rows * m_cols );
 
  283    cufftExecC2C( m_fftPlan, (cuComplex *)noise, (cuComplex *)noise, CUFFT_INVERSE );
 
  285    mx::cuda::scalarMul<cuComplex>
 
  286        <<<32, 256>>>( (cuComplex *)noise, (cuComplex *)m_scale.m_devicePtr, m_rows * m_cols );
 
  294template <
typename realT>
 
  295int psdFilter<realT, 1>::operator()( realArrayT &noise )
 
  297    return filter( noise );
 
  300template <
typename realT>
 
  301int psdFilter<realT, 1>::operator()( realArrayT &noise, realArrayT &noiseIm )
 
  303    return filter( noise, &noiseIm );
 
int psdSqrt(const realArrayT &npsdSqrt)
Set the square root of the PSD.
_realT realT
Real floating point type.
int filter(deviceComplexPtrT *noise)
Apply the filter.
int cols()
Get the number of columns in the filter.
std::complex< _realT > complexT
Complex floating point type.
int psdSqrt(const cuda::cudaPtr< std::complex< realT > > &npsdSqrt, size_t rows, size_t cols)
Eigen::Array< realT, Eigen::Dynamic, Eigen::Dynamic > realArrayT
Eigen array type with Scalar==realT.
int operator()(realArrayT &noise, realArrayT &noiseIm)
Apply the filter.
mx::cuda::cudaPtr< complexT > m_scale
The scale factor, 1/(rows*cols), applied to the result after the inverse FFT.
mx::cuda::cudaPtr< complexT > m_psdSqrt
Pointer to the device array containing the square root of the PSD, stored as complex values.
int setSize()
Set the size of the filter.
Eigen::Array< complexT, Eigen::Dynamic, Eigen::Dynamic > complexArrayT
Eigen array type with Scalar==complexT.
int operator()(realArrayT &noise)
Apply the filter.
void clear()
Deallocate all working memory and reset to the initial state.
int rows()
Get the number of rows in the filter.
Declares and defines a class for filtering with PSDs.
A smart-pointer wrapper for CUDA device pointers.