8#ifndef wfp_fraunhoferPropagatorCuda_hpp
9#define wfp_fraunhoferPropagatorCuda_hpp
11#include "../math/constants.hpp"
14#include "../cuda/templateCuda.hpp"
15#include "../cuda/templateCudaPtr.hpp"
// Class to perform Fraunhofer propagation between pupil and focal planes,
// using cuFFT on the GPU.
// NOTE(review): this listing appears to be a partial extraction. The leading
// numbers look like original line numbers, and the lines between them (other
// members, the class braces) are not shown here.
33template <
typename _wavefrontT>
// The real data type: the value_type of the wavefront's complex Scalar type.
44 typedef typename wavefrontT::Scalar::value_type
realT;
// Edge length of the wavefront in pixels. Starts at 0 and is set by
// setWavefrontSizePixels(), which also builds m_fftPlan for this size.
50 int wavefrontSizePixels{ 0 };
// cuFFT plan handle, created by setWavefrontSizePixels() and released with
// cufftDestroy. Initialized to 0; presumably 0 means "no plan yet" -- confirm.
64 cufftHandle m_fftPlan{ 0 };
// Fragments of member-function parameter lists (presumably the propagate*
// methods); the rest of those declarations is not shown.
88 devicePtrT *complexPupil
99 devicePtrT *complexFocal
// Out-of-class member definition whose signature and body are not shown in
// this listing (presumably the constructor, fraunhoferPropagator()).
109template <
typename wavefrontT>
// Destructor (presumably ~fraunhoferPropagator()): releases the cuFFT plan.
// NOTE(review): the lines between the template header and this call are not
// shown. If nothing guards against m_fftPlan still being 0 (no plan was ever
// created), cufftDestroy may report an invalid plan. checkCudaErrors would
// then flag it (in the usual helper_cuda.h version this aborts the program,
// which is risky inside a destructor) -- confirm.
114template <
typename wavefrontT>
119 checkCudaErrors( cufftDestroy( m_fftPlan ) );
// setWavefrontSizePixels(wfsPix): set the wavefront edge length in pixels,
// record the grid center, and build a square (wfsPix x wfsPix)
// complex-to-complex cuFFT plan for it.
123template <
typename wavefrontT>
// Nothing to do when the size is unchanged. The early-exit statement itself
// is not shown in this listing.
127 if( wfsPix == wavefrontSizePixels )
130 wavefrontSizePixels = wfsPix;
// Geometric center of the pixel grid; makeShiftPhase() measures its tilt
// phase relative to (xcen, ycen).
132 xcen = 0.5 * ( wfsPix - 1.0 );
133 ycen = 0.5 * ( wfsPix - 1.0 );
// NOTE(review): the lines between here and the plan creation are not shown.
// If a plan from a previous size still exists, it must be released with
// cufftDestroy before m_fftPlan is overwritten, otherwise it leaks -- confirm.
138 checkCudaErrors( cufftPlan2d( &m_fftPlan, wavefrontSizePixels, wavefrontSizePixels, CUFFT_C2C ) );
// propagatePupilToFocal(complexFocal, complexPupil): apply the m_centerFocal
// tilt/normalization screen to the pupil field, then forward-FFT it into the
// focal-plane buffer. Both pointers are expected to be device buffers holding
// wavefrontSizePixels * wavefrontSizePixels complex values.
141template <
typename wavefrontT>
// Multiply complexPupil element-wise by m_centerFocal. This is presumably in
// place, so the caller's pupil buffer is modified -- confirm against
// pointwiseMul. The fixed <<<32, 256>>> launch covers every element only if
// pointwiseMul loops over its input (TODO confirm).
// NOTE(review): no checkCudaErrors( cudaGetLastError() ) follows this launch,
// so launch failures go unreported.
145 mx::cuda::pointwiseMul<cuComplex><<<32, 256>>>(
146 (cuComplex *)complexPupil, (cuComplex *)m_centerFocal.m_devicePtr, wavefrontSizePixels * wavefrontSizePixels );
// Forward 2-D FFT: pupil -> focal.
// NOTE(review): the cufftResult returned here is discarded, unlike the
// cufftPlan2d/cufftDestroy calls in this file, which are wrapped in
// checkCudaErrors. A failed transform would go unnoticed; wrap it the same way.
148 cufftExecC2C( m_fftPlan, (cufftComplex *)complexPupil, (cufftComplex *)complexFocal, CUFFT_FORWARD );
// propagateFocalToPupil(complexPupil, complexFocal): inverse-FFT the focal
// field into the pupil buffer, then apply the m_centerPupil screen, which
// un-tilts the pupil plane after propagating from a centered focal plane.
151template <
typename wavefrontT>
// Inverse 2-D FFT: focal -> pupil.
// NOTE(review): the cufftResult returned here is discarded, unlike the
// cufftPlan2d/cufftDestroy calls in this file, which are wrapped in
// checkCudaErrors. A failed transform would go unnoticed; wrap it the same way.
154 cufftExecC2C( m_fftPlan, (cufftComplex *)complexFocal, (cufftComplex *)complexPupil, CUFFT_INVERSE );
// Multiply complexPupil element-wise by m_centerPupil (presumably in place --
// confirm against pointwiseMul). The fixed <<<32, 256>>> launch covers every
// element only if pointwiseMul loops over its input (TODO confirm).
// NOTE(review): no checkCudaErrors( cudaGetLastError() ) follows this launch,
// so launch failures go unreported.
157 mx::cuda::pointwiseMul<cuComplex><<<32, 256>>>(
158 (cuComplex *)complexPupil, (cuComplex *)m_centerPupil.m_devicePtr, wavefrontSizePixels * wavefrontSizePixels );
// makeShiftPhase(): build the tilt screens on the host, with the normalization
// folded in, and upload them to m_centerFocal / m_centerPupil on the device.
// m_centerFocal centers the focal plane; m_centerPupil undoes that tilt on the
// way back.
161template <
typename wavefrontT>
// |cnorm| = norm * sqrt(2) = 1 / wavefrontSizePixels, so each screen scales
// the field by 1/N. A pupil->focal->pupil round trip therefore scales it by
// 1/N^2, presumably to cancel the unnormalized cuFFT forward+inverse pair --
// confirm.
167 realT norm = 1. / ( wavefrontSizePixels * sqrt( 2 ) );
168 complexT cnorm = complexT( norm, norm );
// Host scratch buffers, freed at the end of this function.
// NOTE(review): if anything between here and the delete[]s throws, these
// buffers leak; std::vector<complexT> would make this exception-safe.
171 complexT *centerFocal =
new complexT[wavefrontSizePixels * wavefrontSizePixels];
173 complexT *centerPupil =
new complexT[wavefrontSizePixels * wavefrontSizePixels];
// Tilt phase slope: -2*pi*0.5*N/(N-1). `pi` is presumably a realT constant
// defined on a line not shown here (the visible pi() helper is a function).
// NOTE(review): wavefrontSizePixels == 1 divides by zero here. arg becomes
// infinite and both screens become NaN; consider rejecting sizes < 2.
176 realT arg = -2.0 *
pi * 0.5 * wavefrontSizePixels / ( wavefrontSizePixels - 1 );
178 for(
int ii = 0; ii < wavefrontSizePixels; ++ii )
180 for(
int jj = 0; jj < wavefrontSizePixels; ++jj )
// Element (ii, jj) is stored at ii * N + jj.
// Focal-centering screen: a linear phase tilt arg * ((ii - xcen) + (jj - ycen))
// measured from the grid center.
182 centerFocal[ii * wavefrontSizePixels + jj] =
183 cnorm * exp( complexT( 0., arg * ( ( ii - xcen ) + ( jj - ycen ) ) ) );
// Pupil screen: the opposite tilt plus a constant 0.5 * pi phase offset.
184 centerPupil[ii * wavefrontSizePixels + jj] =
185 cnorm * exp( complexT( 0., 0.5 * pi - arg * ( ( ii - xcen ) + ( jj - ycen ) ) ) );
// Copy both screens to device memory.
189 m_centerFocal.upload( centerFocal, wavefrontSizePixels * wavefrontSizePixels );
191 m_centerPupil.upload( centerPupil, wavefrontSizePixels * wavefrontSizePixels );
193 delete[] centerFocal;
194 delete[] centerPupil;
fraunhoferPropagator()
Constructor.
void setWavefrontSizePixels(int wfsPix)
Set the size of the wavefront, in pixels.
~fraunhoferPropagator()
Destructor.
wavefrontT::Scalar::value_type realT
The real data type.
void propagatePupilToFocal(devicePtrT *complexFocal, devicePtrT *complexPupil)
Propagate the wavefront from the pupil plane to the focal plane.
void propagateFocalToPupil(devicePtrT *complexPupil, devicePtrT *complexFocal)
Propagate the wavefront from the focal plane to the pupil plane.
void makeShiftPhase()
Calculate the complex tilt arrays for centering and normalizing the wavefronts.
wavefrontT::Scalar complexT
The complex data type.
_wavefrontT wavefrontT
The wavefront data type.
mx::cuda::cudaPtr< complexT > m_centerFocal
Phase screen for tilting the pupil plane so that the focal plane image is centered [GPU memory].
mx::cuda::cudaPtr< complexT > m_centerPupil
Phase screen for un-tilting the pupil plane after propagating from a centered focal plane [GPU memory].
Class to perform Fraunhofer propagation between pupil and focal planes.
fraunhoferPropagator()
Constructor.
void makeShiftPhase()
Calculate the complex tilt arrays for centering and normalizing the wavefronts.
void setWavefrontSizePixels(int wfsPix)
Set the size of the wavefront, in pixels.
~fraunhoferPropagator()
Destructor.
void propagatePupilToFocal(wavefrontT &complexFocal, wavefrontT &complexPupil, bool doCenter=true)
Propagate the wavefront from the pupil plane to the focal plane.
void propagateFocalToPupil(wavefrontT &complexPupil, wavefrontT &complexFocal, bool doCenter=true)
Propagate the wavefront from the focal plane to the pupil plane.
Declares and defines a class for Fraunhofer propagation of optical wavefronts.
constexpr T pi()
Get the value of pi.
constexpr floatT six_fifths()
Return 6/5 in the specified precision.
A smart-pointer wrapper for CUDA device pointers.