9 #ifndef wfp_fraunhoferPropagatorCuda_hpp
10 #define wfp_fraunhoferPropagatorCuda_hpp
12 #include "../math/constants.hpp"
15 #include "../cuda/templateCuda.hpp"
16 #include "../cuda/templateCudaPtr.hpp"
34 template<
typename _wavefrontT>
46 typedef typename wavefrontT::Scalar::value_type
realT;
53 int wavefrontSizePixels {0};
67 cufftHandle m_fftPlan {0};
89 devicePtrT * complexPupil
99 devicePtrT * complexFocal
112 template<
typename wavefrontT>
117 template<
typename wavefrontT>
122 checkCudaErrors(cufftDestroy(m_fftPlan));
126 template<
typename wavefrontT>
130 if(wfsPix == wavefrontSizePixels)
return;
132 wavefrontSizePixels = wfsPix;
134 xcen = 0.5*(wfsPix - 1.0);
135 ycen = 0.5*(wfsPix - 1.0);
140 checkCudaErrors(cufftPlan2d(&m_fftPlan, wavefrontSizePixels, wavefrontSizePixels, CUFFT_C2C));
144 template<
typename wavefrontT>
146 devicePtrT * complexPupil
150 mx::cuda::pointwiseMul<cuComplex><<<32, 256>>>( (cuComplex *) complexPupil, (cuComplex *) m_centerFocal.m_devicePtr, wavefrontSizePixels*wavefrontSizePixels);
152 cufftExecC2C(m_fftPlan, (cufftComplex *) complexPupil, (cufftComplex *) complexFocal, CUFFT_FORWARD);
156 template<
typename wavefrontT>
158 devicePtrT * complexFocal
161 cufftExecC2C(m_fftPlan, (cufftComplex *) complexFocal, (cufftComplex *) complexPupil, CUFFT_INVERSE);
164 mx::cuda::pointwiseMul<cuComplex><<<32, 256>>>( (cuComplex*) complexPupil, (cuComplex*) m_centerPupil.m_devicePtr, wavefrontSizePixels*wavefrontSizePixels);
167 template<
typename wavefrontT>
170 constexpr realT
pi = math::pi<realT>();
173 realT norm = 1./(wavefrontSizePixels*sqrt(2));
174 complexT cnorm = complexT(norm, norm);
177 complexT * centerFocal =
new complexT[wavefrontSizePixels*wavefrontSizePixels];
179 complexT * centerPupil =
new complexT[wavefrontSizePixels*wavefrontSizePixels];
182 realT arg = -2.0*
pi*0.5*wavefrontSizePixels/(wavefrontSizePixels-1);
184 for(
int ii=0; ii < wavefrontSizePixels; ++ii)
186 for(
int jj=0; jj < wavefrontSizePixels; ++jj)
188 centerFocal[ii*wavefrontSizePixels + jj] = cnorm*exp(complexT(0.,arg*((ii-xcen)+(jj-ycen))));
189 centerPupil[ii*wavefrontSizePixels + jj] = cnorm*exp(complexT(0., 0.5*
pi - arg*((ii-xcen)+(jj-ycen))));
193 m_centerFocal.upload(centerFocal, wavefrontSizePixels*wavefrontSizePixels);
195 m_centerPupil.upload(centerPupil, wavefrontSizePixels*wavefrontSizePixels);
197 delete[] centerFocal;
198 delete[] centerPupil;
fraunhoferPropagator()
Constructor.
void setWavefrontSizePixels(int wfsPix)
Set the size of the wavefront, in pixels.
~fraunhoferPropagator()
Destructor.
wavefrontT::Scalar::value_type realT
The real data type.
void propagatePupilToFocal(devicePtrT *complexFocal, devicePtrT *complexPupil)
Propagate the wavefront from the pupil plane to the focal plane.
void propagateFocalToPupil(devicePtrT *complexPupil, devicePtrT *complexFocal)
Propagate the wavefront from the focal plane to the pupil plane.
void makeShiftPhase()
Calculate the complex tilt arrays for centering and normalizing the wavefronts.
wavefrontT::Scalar complexT
The complex data type.
_wavefrontT wavefrontT
The wavefront data type.
mx::cuda::cudaPtr< complexT > m_centerFocal
Phase screen for tilting the pupil plane so that the focal plane image is centered [GPU memory].
mx::cuda::cudaPtr< complexT > m_centerPupil
Phase screen for un-tilting the pupil plane after propagating from a centered focal plane [GPU memory].
Class to perform Fraunhofer propagation between pupil and focal planes.
fraunhoferPropagator()
Constructor.
void makeShiftPhase()
Calculate the complex tilt arrays for centering and normalizing the wavefronts.
void setWavefrontSizePixels(int wfsPix)
Set the size of the wavefront, in pixels.
~fraunhoferPropagator()
Destructor.
void propagatePupilToFocal(wavefrontT &complexFocal, wavefrontT &complexPupil, bool doCenter=true)
Propagate the wavefront from the pupil plane to the focal plane.
void propagateFocalToPupil(wavefrontT &complexPupil, wavefrontT &complexFocal, bool doCenter=true)
Propagate the wavefront from the focal plane to the pupil plane.
Declares and defines a class for Fraunhofer propagation of optical wavefronts.
constexpr T pi()
Get the value of pi.