27#ifndef psdFilterCuda_hpp
28#define psdFilterCuda_hpp
30#include "../cuda/templateCudaPtr.hpp"
31#include "../cuda/templateCuda.hpp"
38#include <helper_cuda.h>
// Specialization psdFilter<_realT, 1>: the GPU implementation of the PSD filter.
// It keeps its working arrays in CUDA device memory (mx::cuda::cudaPtr members).
// It uses a cuFFT plan (m_fftPlan) for the 2D transforms.
// NOTE(review): this view of the class is incomplete. Most members, the name of the
// second array typedef, and the closing brace are not visible here.
69template <
typename _realT>
70class psdFilter<_realT, 1>
// Host-side Eigen array types. Neither specifies a storage order, so both use Eigen's
// default (column-major) layout.
75 typedef Eigen::Array<realT, Eigen::Dynamic, Eigen::Dynamic>
realArrayT;
76 typedef Eigen::Array<complexT, Eigen::Dynamic, Eigen::Dynamic>
// NOTE(review): despite the "Ptr" in its name, this aliases realT itself, not realT*.
// Confirm this is intended.
79 typedef realT deviceRealPtrT;
// cuFFT plan handle. It stays 0 until setSize() creates a 2D C2C plan with cufftPlan2d.
91 cufftHandle m_fftPlan{ 0 };
// Default constructor. m_fftPlan starts at 0 through its in-class initializer.
170template <
typename realT>
171psdFilter<realT, 1>::psdFilter()
// Destructor: releases the cuFFT plan with cufftDestroy and passes the returned status
// to checkCudaErrors.
// NOTE(review): if setSize() was never called, m_fftPlan is still 0 and no plan was ever
// created by cuFFT. Confirm the elided lines guard this call (e.g. skip it when 0).
175template <
typename realT>
176psdFilter<realT, 1>::~psdFilter()
180 checkCudaErrors( cufftDestroy( m_fftPlan ) );
// Creates the device-side state for the current m_rows x m_cols size:
//  - uploads the normalization factor 1/(m_rows*m_cols) to m_scale, and
//  - creates a 2D complex-to-complex cuFFT plan in m_fftPlan.
// NOTE(review): no existing plan in m_fftPlan is destroyed in the visible lines.
// Calling this a second time (e.g. after a new psdSqrt()) may leak the old plan. Confirm.
// NOTE(review): cufftPlan2d takes the slowest-varying (row-major) dimension first.
// Data uploaded from Eigen arrays is column-major by default. For non-square filters the
// size arguments may therefore need to be (m_cols, m_rows). Confirm.
// NOTE(review): CUFFT_C2C (and the cuComplex casts in filter()) are single-precision.
// This only matches realT == float. Confirm double is never instantiated.
184template <
typename realT>
185int psdFilter<realT, 1>::setSize()
// cuFFT's inverse transform is unnormalized. filter() multiplies by this factor after
// CUFFT_INVERSE to restore the original amplitude.
189 std::complex<realT> scale( 1. / ( m_rows * m_cols ), 0 );
190 m_scale.upload( &scale, 1 );
192 checkCudaErrors( cufftPlan2d( &m_fftPlan, m_rows, m_cols, CUFFT_C2C ) );
// Get the number of rows in the filter.
197template <
typename realT>
198int psdFilter<realT, 1>::rows()
// Get the number of columns in the filter.
203template <
typename realT>
204int psdFilter<realT, 1>::cols()
// Set the square root of the PSD from a host-side real array.
// Records the array's dimensions in m_rows/m_cols.
// Copies the array into the real parts of a temporary complex array.
// Uploads that complex array to m_psdSqrt on the device.
225template <
typename realT>
226int psdFilter<realT, 1>::psdSqrt(
const realArrayT &npsdSqrt )
228 m_rows = npsdSqrt.rows();
229 m_cols = npsdSqrt.cols();
// NOTE(review): Eigen does not zero-initialize `tmp`, and only tmp.real() is assigned in
// the visible lines. The uploaded imaginary parts may therefore be uninitialized.
// Confirm the elided lines set tmp.imag() to zero.
231 complexArrayT tmp( m_rows, m_cols );
233 tmp.real() = npsdSqrt;
235 m_psdSqrt.upload( tmp.data(), m_rows * m_cols );
// Set the square root of the PSD from complex data that is already on the device.
// Resizes m_psdSqrt to m_rows*m_cols elements, then copies them device-to-device from
// npsdSqrt.
// NOTE(review): the lines that store rows/cols into m_rows/m_cols are not visible here.
// Both the resize and the copy size depend on them being set first.
// NOTE(review): npsdSqrt is assumed to hold at least rows*cols elements. Confirm
// against callers.
242template <
typename realT>
243int psdFilter<realT, 1>::psdSqrt(
const cuda::cudaPtr<std::complex<realT>> &npsdSqrt,
size_t rows,
size_t cols )
248 m_psdSqrt.resize( m_rows * m_cols );
// NOTE(review): the cudaError_t returned by cudaMemcpy is not checked. This differs from
// the cuFFT calls elsewhere in this file, which are wrapped in checkCudaErrors.
251 cudaMemcpy( m_psdSqrt.m_devicePtr,
252 npsdSqrt.m_devicePtr,
253 m_rows * m_cols *
sizeof( std::complex<realT> ),
254 cudaMemcpyDeviceToDevice );
// De-allocate working memory and reset to the initial state.
// The visible line resizes m_psdSqrt to zero elements (presumably releasing its device
// allocation).
// NOTE(review): it is not visible whether m_fftPlan is destroyed and m_scale released
// here. Confirm, since setSize() creates both.
261template <
typename realT>
262void psdFilter<realT, 1>::clear()
268 m_psdSqrt.resize( 0 );
// Apply the filter in place to complex data on the device.
// Steps:
//   1. forward 2D FFT;
//   2. pointwise multiply by m_psdSqrt;
//   3. inverse 2D FFT;
//   4. multiply by m_scale (1/(m_rows*m_cols), set in setSize()) to normalize.
// `noise` is reinterpreted as cuComplex*. It is therefore assumed to point to at least
// m_rows*m_cols complex values in device memory. Confirm against callers.
// NOTE(review): the cufftExecC2C return codes are not checked, and no cudaGetLastError()
// follows the kernel launches in the visible lines.
271template <
typename realT>
272int psdFilter<realT, 1>::filter( deviceComplexPtrT *noise )
277 cufftExecC2C( m_fftPlan, (cuComplex *)noise, (cuComplex *)noise, CUFFT_FORWARD );
// Fixed launch of 32 blocks x 256 threads over m_rows*m_cols elements. Covering larger
// arrays relies on pointwiseMul/scalarMul looping over the full range. Confirm in
// templateCuda.hpp.
280 mx::cuda::pointwiseMul<cuComplex>
281 <<<32, 256>>>( (cuComplex *)noise, (cufftComplex *)m_psdSqrt.m_devicePtr, m_rows * m_cols );
283 cufftExecC2C( m_fftPlan, (cuComplex *)noise, (cuComplex *)noise, CUFFT_INVERSE );
285 mx::cuda::scalarMul<cuComplex>
286 <<<32, 256>>>( (cuComplex *)noise, (cuComplex *)m_scale.m_devicePtr, m_rows * m_cols );
// Apply the filter to a real host array by forwarding to filter(noise).
// NOTE(review): the only filter() visible in this file takes deviceComplexPtrT*, not
// realArrayT&. Confirm a matching overload is declared; otherwise this will not compile
// when instantiated.
294template <
typename realT>
295int psdFilter<realT, 1>::operator()( realArrayT &noise )
297 return filter( noise );
// Apply the filter to a pair of real host arrays by forwarding to
// filter(noise, &noiseIm). Presumably noiseIm carries the imaginary component; verify
// against the caller.
// NOTE(review): no two-argument filter() is visible in this file. Confirm one is
// declared; otherwise this will not compile when instantiated.
300template <
typename realT>
301int psdFilter<realT, 1>::operator()( realArrayT &noise, realArrayT &noiseIm )
303 return filter( noise, &noiseIm );
int psdSqrt(const realArrayT &npsdSqrt)
Set the square root of the PSD.
_realT realT
Real floating point type.
int filter(deviceComplexPtrT *noise)
Apply the filter.
int cols()
Get the number of columns in the filter.
std::complex< _realT > complexT
Complex floating point type.
int psdSqrt(const cuda::cudaPtr< std::complex< realT > > &npsdSqrt, size_t rows, size_t cols)
Eigen::Array< realT, Eigen::Dynamic, Eigen::Dynamic > realArrayT
Eigen array type with Scalar==realT.
int operator()(realArrayT &noise, realArrayT &noiseIm)
Apply the filter.
mx::cuda::cudaPtr< complexT > m_scale
The normalization scale factor, 1/(rows*cols), applied after the inverse FFT.
mx::cuda::cudaPtr< complexT > m_psdSqrt
Device array of complex values holding the square root of the PSD.
int setSize()
Set the size of the filter.
Eigen::Array< complexT, Eigen::Dynamic, Eigen::Dynamic > complexArrayT
Eigen array type with Scalar==complexT.
int operator()(realArrayT &noise)
Apply the filter.
void clear()
De-allocate all working memory and reset to initial state.
int rows()
Get the number of rows in the filter.
Declares and defines a class for filtering with PSDs.
A smart-pointer wrapper for cuda device pointers.