mxlib
c++ tools for analyzing astronomical data and other tasks by Jared R. Males. [git repo]
Loading...
Searching...
No Matches
fraunhoferPropagatorCuda.hpp
Go to the documentation of this file.
1/** \file fraunhoferPropagatorCuda.hpp
2 * \brief Declares and defines a class for Fraunhofer propagation of optical wavefronts with Cuda
3 * \ingroup imaging
4 * \author Jared R. Males (jaredmales@gmail.com)
5 *
6 */
7
8#ifndef wfp_fraunhoferPropagatorCuda_hpp
9#define wfp_fraunhoferPropagatorCuda_hpp
10
11#include "../math/constants.hpp"
12
14#include "../cuda/templateCuda.hpp"
15#include "../cuda/templateCudaPtr.hpp"
16
17namespace mx
18{
19namespace wfp
20{
21
22/// Class to perform Fraunhofer propagation between pupil and focal planes using a GPU
23/** This class uses the FFT to propagate between planes, and normalizes so that flux
24 * is conserved. For propagation from pupil to focal plane, the pupil wavefront is tilted so that
25 * the focal-plane image is centered at the geometric center of the array. After propagation from
26 * focal plane to pupil plane, the pupil plane wavefront is un-tilted to restore the
27 * pupil to its original position.
28 *
29 * \tparam _wavefrontT is an Eigen::Array-like type, with std::complex values.
30 *
31 * \ingroup imaging
32 */
33template <typename _wavefrontT>
34class fraunhoferPropagator<_wavefrontT, 1>
35{
36 public:
37 /// The wavefront data type
38 typedef _wavefrontT wavefrontT;
39
40 /// The complex data type
41 typedef typename wavefrontT::Scalar complexT;
42
43 /// The real data type
44 typedef typename wavefrontT::Scalar::value_type realT;
45
/// The element type addressed by the device pointers passed to the propagate methods
46 typedef complexT devicePtrT;
47
48 protected:
49 /// The size of the wavefront in pixels
50 int wavefrontSizePixels{ 0 };
51
52 realT xcen{ 0 }; ///< x-coordinate of focal plane center, in pixels
53 realT ycen{ 0 }; ///< y-coordinate of focal plane center, in pixels
54
55 /// Phase screen for tilting the pupil plane so that the focal plane image is centered [GPU memory].
56 // complexT * m_centerFocal {nullptr};
58
59 /// Phase screen for un-tilting the pupil plane after propagating from a centered focal plane [GPU memory].
60 // complexT * m_centerPupil {nullptr};
62
63 /// Cuda FFT plan. We only need one since the forward/inverse is part of execution.
64 cufftHandle m_fftPlan{ 0 };
65
66 public:
67 /// Constructor
69
70 /// Destructor
72
73 /// Set the size of the wavefront, in pixels
74 /** Checks if the size changes, does nothing if no change. Otherwise, updates the focal plane
 * center coordinates, calls
75 * \ref makeShiftPhase to pre-calculate the tilt arrays, and plans the FFT.
76 *
77 */
78 void setWavefrontSizePixels( int wfsPix /**< [in] the desired new size of the wavefront */ );
79
80 /// Propagate the wavefront from the pupil plane to the focal plane
81 /** The pupil plane wavefront (complexPupil) is multiplied in place by a tilt to place the
82 * image in the geometric center of the focal plane, and is then forward transformed into complexFocal.
83 *
84 */
86 devicePtrT
87 *complexFocal, ///< [out] the focal plane wavefront. Must be pre-allocated to same size as complexPupil.
88 devicePtrT *complexPupil ///< [in,out] the pupil plane wavefront. Modified due to application of centering tilt.
89 );
90
91 /// Propagate the wavefront from Focal plane to Pupil plane
92 /**
93 * After the inverse Fourier transform, the output pupil plane wavefront is de-tilted, restoring it
94 * to the state prior to calling \ref propagatePupilToFocal
95 *
96 */
97 void propagateFocalToPupil( devicePtrT *complexPupil, ///< [out] the pupil plane wavefront. Must be pre-allocated to
98 ///< same size as complexFocal.
99 devicePtrT *complexFocal ///< [in] the focal plane wavefront.
100 );
101
102 protected:
103 /// Calculate the complex tilt arrays for centering and normalizing the wavefronts
104 /** The screens are built in host memory and then uploaded to the device.
105 */
107};
108
109template <typename wavefrontT>
111{
112}
113
114template <typename wavefrontT>
116{
117 if( m_fftPlan )
118 {
119 checkCudaErrors( cufftDestroy( m_fftPlan ) );
120 }
121}
122
123template <typename wavefrontT>
125{
126 // If no change in size, do nothing
127 if( wfsPix == wavefrontSizePixels )
128 return;
129
130 wavefrontSizePixels = wfsPix;
131
132 xcen = 0.5 * ( wfsPix - 1.0 );
133 ycen = 0.5 * ( wfsPix - 1.0 );
134
135 makeShiftPhase();
136
137 // Plan the FFT
138 checkCudaErrors( cufftPlan2d( &m_fftPlan, wavefrontSizePixels, wavefrontSizePixels, CUFFT_C2C ) );
139}
140
141template <typename wavefrontT>
142void fraunhoferPropagator<wavefrontT, 1>::propagatePupilToFocal( devicePtrT *complexFocal, devicePtrT *complexPupil )
143{
144 // Apply the centering shift -- this adjusts by 0.5 pixels and normalizes
145 mx::cuda::pointwiseMul<cuComplex><<<32, 256>>>(
146 (cuComplex *)complexPupil, (cuComplex *)m_centerFocal.m_devicePtr, wavefrontSizePixels * wavefrontSizePixels );
147
148 cufftExecC2C( m_fftPlan, (cufftComplex *)complexPupil, (cufftComplex *)complexFocal, CUFFT_FORWARD );
149}
150
151template <typename wavefrontT>
152void fraunhoferPropagator<wavefrontT, 1>::propagateFocalToPupil( devicePtrT *complexPupil, devicePtrT *complexFocal )
153{
154 cufftExecC2C( m_fftPlan, (cufftComplex *)complexFocal, (cufftComplex *)complexPupil, CUFFT_INVERSE );
155
156 // Unshift the wavefront and normalize
157 mx::cuda::pointwiseMul<cuComplex><<<32, 256>>>(
158 (cuComplex *)complexPupil, (cuComplex *)m_centerPupil.m_devicePtr, wavefrontSizePixels * wavefrontSizePixels );
159}
160
161template <typename wavefrontT>
163{
164 constexpr realT pi = math::pi<realT>();
165
166 // The normalization is included in the tilt.
167 realT norm = 1. / ( wavefrontSizePixels * sqrt( 2 ) );
168 complexT cnorm = complexT( norm, norm );
169
170 /// Host memory to build the shift screens
171 complexT *centerFocal = new complexT[wavefrontSizePixels * wavefrontSizePixels];
172
173 complexT *centerPupil = new complexT[wavefrontSizePixels * wavefrontSizePixels];
174
175 // Shift by 0.5 pixels
176 realT arg = -2.0 * pi * 0.5 * wavefrontSizePixels / ( wavefrontSizePixels - 1 );
177
178 for( int ii = 0; ii < wavefrontSizePixels; ++ii )
179 {
180 for( int jj = 0; jj < wavefrontSizePixels; ++jj )
181 {
182 centerFocal[ii * wavefrontSizePixels + jj] =
183 cnorm * exp( complexT( 0., arg * ( ( ii - xcen ) + ( jj - ycen ) ) ) );
184 centerPupil[ii * wavefrontSizePixels + jj] =
185 cnorm * exp( complexT( 0., 0.5 * pi - arg * ( ( ii - xcen ) + ( jj - ycen ) ) ) );
186 }
187 }
188
189 m_centerFocal.upload( centerFocal, wavefrontSizePixels * wavefrontSizePixels );
190
191 m_centerPupil.upload( centerPupil, wavefrontSizePixels * wavefrontSizePixels );
192
193 delete[] centerFocal;
194 delete[] centerPupil;
195}
196
197} // namespace wfp
198} // namespace mx
199
200#endif // wfp_fraunhoferPropagatorCuda_hpp
void setWavefrontSizePixels(int wfsPix)
Set the size of the wavefront, in pixels.
wavefrontT::Scalar::value_type realT
The real data type.
void propagatePupilToFocal(devicePtrT *complexFocal, devicePtrT *complexPupil)
Propagate the wavefront from the pupil plane to the focal plane.
void propagateFocalToPupil(devicePtrT *complexPupil, devicePtrT *complexFocal)
Propagate the wavefront from Focal plane to Pupil plane.
void makeShiftPhase()
Calculate the complex tilt arrays for centering and normalizing the wavefronts.
wavefrontT::Scalar complexT
The complex data type.
mx::cuda::cudaPtr< complexT > m_centerFocal
Phase screen for tilting the pupil plane so that the focal plane image is centered [GPU memory].
mx::cuda::cudaPtr< complexT > m_centerPupil
Phase screen for un-tilting the pupil plane after propagating from a centered focal plane [GPU memory].
Class to perform Fraunhofer propagation between pupil and focal planes.
void makeShiftPhase()
Calculate the complex tilt arrays for centering and normalizing the wavefronts.
void setWavefrontSizePixels(int wfsPix)
Set the size of the wavefront, in pixels.
void propagatePupilToFocal(wavefrontT &complexFocal, wavefrontT &complexPupil, bool doCenter=true)
Propagate the wavefront from the pupil plane to the focal plane.
void propagateFocalToPupil(wavefrontT &complexPupil, wavefrontT &complexFocal, bool doCenter=true)
Propagate the wavefront from Focal plane to Pupil plane.
Declares and defines a class for Fraunhofer propagation of optical wavefronts.
constexpr T pi()
Get the value of pi.
Definition constants.hpp:55
constexpr floatT six_fifths()
Return 6/5 in the specified precision.
The mxlib c++ namespace.
Definition mxError.hpp:106
A smart-pointer wrapper for cuda device pointers.
Definition cudaPtr.hpp:47