mxlib
C++ tools for analyzing astronomical data and other tasks, by Jared R. Males. [git repo]
fraunhoferPropagatorCuda.hpp
Go to the documentation of this file.
1 /** \file fraunhoferPropagatorCuda.hpp
2  * \brief Declares and defines a class for Fraunhofer propagation of optical wavefronts with Cuda
3  * \ingroup imaging
4  * \author Jared R. Males (jaredmales@gmail.com)
5  *
6  */
7 
8 
9 #ifndef wfp_fraunhoferPropagatorCuda_hpp
10 #define wfp_fraunhoferPropagatorCuda_hpp
11 
#include <vector>

#include "../math/constants.hpp"

#include "fraunhoferPropagator.hpp"
#include "../cuda/templateCuda.hpp"
#include "../cuda/templateCudaPtr.hpp"
17 
18 namespace mx
19 {
20 namespace wfp
21 {
22 
/// Class to perform Fraunhofer propagation between pupil and focal planes using a GPU
/** This class uses the FFT to propagate between planes, and normalizes so that flux
  * is conserved. For propagation from pupil to focal plane, the pupil wavefront is tilted so that
  * the focal-plane image is centered at the geometric center of the array. After propagation from
  * focal plane to pupil plane, the pupil plane wavefront is un-tilted to restore the
  * pupil to its original position.
  *
  * The device buffers are processed as single-precision complex data (they are cast to
  * cuComplex / cufftComplex and transformed with a CUFFT_C2C plan).
  * NOTE(review): complexT is therefore presumably required to be std::complex<float>; a
  * double-precision wavefrontT would be silently reinterpreted as float data -- confirm.
  *
  * NOTE(review): the destructor calls cufftDestroy on m_fftPlan, so a copied instance would
  * share the plan handle with the original and both destructors would destroy it -- avoid
  * copying instances (or delete the copy operations).
  *
  * \tparam _wavefrontT is an Eigen::Array-like type, with std::complex values.
  *
  * \ingroup imaging
  */
template<typename _wavefrontT>
class fraunhoferPropagator<_wavefrontT,1>
{
public:

   ///The wavefront data type
   typedef _wavefrontT wavefrontT;

   ///The complex data type
   typedef typename wavefrontT::Scalar complexT;

   ///The real data type
   typedef typename wavefrontT::Scalar::value_type realT;

   ///Element type of the device buffers handed to the propagation methods (used as devicePtrT *).
   typedef complexT devicePtrT;

protected:

   ///The size of the wavefront in pixels
   int wavefrontSizePixels {0};

   realT xcen {0}; ///<x-coordinate of focal plane center, in pixels
   realT ycen {0}; ///<y-coordinate of focal plane center, in pixels

   ///Phase screen for tilting the pupil plane so that the focal plane image is centered [GPU memory].
   mx::cuda::cudaPtr<complexT> m_centerFocal;

   ///Phase screen for un-tilting the pupil plane after propagating from a centered focal plane [GPU memory].
   mx::cuda::cudaPtr<complexT> m_centerPupil;

   ///Cuda FFT plan. We only need one since the forward/inverse is part of execution.
   /** A value of 0 is treated as "no plan created yet". */
   cufftHandle m_fftPlan {0};

public:
   ///Constructor
   fraunhoferPropagator();

   ///Destructor
   /** Destroys the cuFFT plan if one has been created.
     */
   ~fraunhoferPropagator();

   ///Set the size of the wavefront, in pixels
   /** Checks if the size changes, does nothing if no change. Otherwise, sets the focal plane
     * center (xcen, ycen) to 0.5*(wfsPix-1), calls \ref makeShiftPhase to pre-calculate the
     * tilt arrays, and plans the FFT.
     *
     */
   void setWavefrontSizePixels( int wfsPix /**< [in] the desired new size of the wavefront */ );

   ///Propagate the wavefront from the pupil plane to the focal plane
   /** The pupil plane wavefront (complexPupil) is multiplied by a tilt to place the
     * image in the geometric center of the focal plane.
     *
     * Both pointers must refer to device memory holding wavefrontSizePixels*wavefrontSizePixels
     * elements, and \ref setWavefrontSizePixels must have been called first.
     */
   void propagatePupilToFocal( devicePtrT * complexFocal, ///< [out] the focal plane wavefront. Must be pre-allocated to same size as complexPupil.
                               devicePtrT * complexPupil  ///< [in] the pupil plane wavefront. Modified due to application of centering tilt.
                             );

   ///Propagate the wavefront from Focal plane to Pupil plane
   /**
     * After the fourier transform, the output pupil plane wavefront is de-tilted, restoring it
     * to the state prior to calling \ref propagatePupilToFocal
     *
     * Both pointers must refer to device memory holding wavefrontSizePixels*wavefrontSizePixels
     * elements, and \ref setWavefrontSizePixels must have been called first.
     */
   void propagateFocalToPupil( devicePtrT * complexPupil, ///< [out] the pupil plane wavefront. Must be pre-allocated to same size as complexFocal.
                               devicePtrT * complexFocal  ///< [in] the focal plane wavefront.
                             );


protected:

   ///Calculate the complex tilt arrays for centering and normalizing the wavefronts
   /** Builds both screens on the host for the current wavefrontSizePixels and uploads them
     * into m_centerFocal and m_centerPupil.
     */
   void makeShiftPhase();

};
111 
//Constructor: all members take their in-class default values.  The FFT plan and the
//phase screens are created later, by setWavefrontSizePixels.
template<typename wavefrontT>
fraunhoferPropagator<wavefrontT,1>::fraunhoferPropagator()
{
}
116 
//Destructor: releases the cuFFT plan, if one was created.
template<typename wavefrontT>
fraunhoferPropagator<wavefrontT,1>::~fraunhoferPropagator()
{
   //m_fftPlan stays 0 until setWavefrontSizePixels creates a plan.
   if( m_fftPlan )
   {
      checkCudaErrors(cufftDestroy(m_fftPlan));
   }
}
125 
//Resize: recompute the center, the phase screens and the FFT plan when the size changes.
template<typename wavefrontT>
void fraunhoferPropagator<wavefrontT,1>::setWavefrontSizePixels( int wfsPix )
{
   //If no change in size, do nothing
   if(wfsPix == wavefrontSizePixels) return;

   wavefrontSizePixels = wfsPix;

   //Geometric center of the array
   xcen = 0.5*(wfsPix - 1.0);
   ycen = 0.5*(wfsPix - 1.0);

   makeShiftPhase();

   //A cuFFT plan is tied to one size.  Release the plan made for the previous size before
   //re-planning; otherwise every resize leaks a plan (the destructor only frees the last one).
   if( m_fftPlan )
   {
      checkCudaErrors(cufftDestroy(m_fftPlan));
      m_fftPlan = 0;
   }

   //Plan the FFT
   checkCudaErrors(cufftPlan2d(&m_fftPlan, wavefrontSizePixels, wavefrontSizePixels, CUFFT_C2C));
}
143 
//Pupil -> focal: apply the centering tilt in place on complexPupil, then forward FFT into complexFocal.
template<typename wavefrontT>
void fraunhoferPropagator<wavefrontT,1>::propagatePupilToFocal( devicePtrT * complexFocal,
                                                                devicePtrT * complexPupil
                                                              )
{
   const int nElements = wavefrontSizePixels*wavefrontSizePixels;

   //Apply the centering tilt, which also carries the flux normalization (see makeShiftPhase).
   mx::cuda::pointwiseMul<cuComplex><<<32, 256>>>( (cuComplex *) complexPupil, (cuComplex *) m_centerFocal.m_devicePtr, nElements);

   //Kernel launches return no status, so check for a launch failure explicitly.
   checkCudaErrors(cudaGetLastError());

   //Check the cufftResult like the other cuFFT calls in this class (e.g. an invalid plan
   //when setWavefrontSizePixels has not been called).
   checkCudaErrors(cufftExecC2C(m_fftPlan, (cufftComplex *) complexPupil, (cufftComplex *) complexFocal, CUFFT_FORWARD));
}
155 
//Focal -> pupil: inverse FFT into complexPupil, then remove the centering tilt in place.
template<typename wavefrontT>
void fraunhoferPropagator<wavefrontT, 1>::propagateFocalToPupil( devicePtrT * complexPupil,
                                                                 devicePtrT * complexFocal
                                                               )
{
   const int nElements = wavefrontSizePixels*wavefrontSizePixels;

   //Check the cufftResult like the other cuFFT calls in this class.
   checkCudaErrors(cufftExecC2C(m_fftPlan, (cufftComplex *) complexFocal, (cufftComplex *) complexPupil, CUFFT_INVERSE));

   //Unshift the wavefront and normalize
   mx::cuda::pointwiseMul<cuComplex><<<32, 256>>>( (cuComplex*) complexPupil, (cuComplex*) m_centerPupil.m_devicePtr, nElements);

   //Kernel launches return no status, so check for a launch failure explicitly.
   checkCudaErrors(cudaGetLastError());
}
166 
//Build the centering (focal) and un-centering (pupil) phase screens on the host and upload them.
//Each screen element is cnorm*exp(i*phase), where phase is a linear tilt measured from the
//geometric center (xcen, ycen); the pupil screen uses the negated tilt plus a constant pi/2.
//NOTE(review): the product of the two screens is cnorm*cnorm*exp(i*pi/2) = -1/N^2 (N = wavefrontSizePixels).
//cuFFT's forward+inverse pair is unnormalized (a factor of N^2), so a pupil->focal->pupil round
//trip appears to return -1 times the input. Confirm this global sign is intended.
template<typename wavefrontT>
void fraunhoferPropagator<wavefrontT,1>::makeShiftPhase()
{
   constexpr realT pi = math::pi<realT>();

   const int nElements = wavefrontSizePixels*wavefrontSizePixels;

   //The normalization is included in the tilt.
   realT norm = 1./(wavefrontSizePixels*sqrt(2));
   complexT cnorm = complexT(norm, norm);

   //Host memory to build the shift screens.  std::vector frees it even if an upload throws.
   std::vector<complexT> centerFocal(nElements);
   std::vector<complexT> centerPupil(nElements);

   //Tilt slope, in radians per pixel, used to center the focal-plane image.
   //For a 1-pixel wavefront every offset from center is zero, so no tilt is needed.  This also
   //avoids dividing by (wavefrontSizePixels-1) == 0, which previously filled the screens with NaN.
   realT arg = 0;
   if(wavefrontSizePixels > 1)
   {
      arg = -2.0*pi*0.5*wavefrontSizePixels/(wavefrontSizePixels-1);
   }

   for(int ii=0; ii < wavefrontSizePixels; ++ii)
   {
      for(int jj=0; jj < wavefrontSizePixels; ++jj)
      {
         //Tilt phase at this pixel, relative to the geometric center
         realT phase = arg*((ii-xcen)+(jj-ycen));

         centerFocal[ii*wavefrontSizePixels + jj] = cnorm*exp(complexT(0., phase));
         centerPupil[ii*wavefrontSizePixels + jj] = cnorm*exp(complexT(0., 0.5*pi - phase));
      }
   }

   m_centerFocal.upload(centerFocal.data(), nElements);

   m_centerPupil.upload(centerPupil.data(), nElements);
}
201 
202 
203 } //namespace wfp
204 } //namespace mx
205 
206 #endif //wfp_fraunhoferPropagatorCuda_hpp
207 
void setWavefrontSizePixels(int wfsPix)
Set the size of the wavefront, in pixels.
wavefrontT::Scalar::value_type realT
The real data type.
void propagatePupilToFocal(devicePtrT *complexFocal, devicePtrT *complexPupil)
Propagate the wavefront from the pupil plane to the focal plane.
void propagateFocalToPupil(devicePtrT *complexPupil, devicePtrT *complexFocal)
Propagate the wavefront from Focal plane to Pupil plane.
void makeShiftPhase()
Calculate the complex tilt arrays for centering and normalizing the wavefronts.
wavefrontT::Scalar complexT
The complex data type.
mx::cuda::cudaPtr< complexT > m_centerFocal
Phase screen for tilting the pupil plane so that the focal plane image is centered [GPU memory].
mx::cuda::cudaPtr< complexT > m_centerPupil
Phase screen for un-tilting the pupil plane after propagating from a centered focal plane [GPU memory].
Class to perform Fraunhofer propagation between pupil and focal planes.
void makeShiftPhase()
Calculate the complex tilt arrays for centering and normalizing the wavefronts.
void setWavefrontSizePixels(int wfsPix)
Set the size of the wavefront, in pixels.
void propagatePupilToFocal(wavefrontT &complexFocal, wavefrontT &complexPupil, bool doCenter=true)
Propagate the wavefront from the pupil plane to the focal plane.
void propagateFocalToPupil(wavefrontT &complexPupil, wavefrontT &complexFocal, bool doCenter=true)
Propagate the wavefront from Focal plane to Pupil plane.
Declares and defines a class for Fraunhofer propagation of optical wavefronts.
constexpr T pi()
Get the value of pi.
Definition: constants.hpp:52
The mxlib c++ namespace.
Definition: mxError.hpp:107