27#ifndef wfp_fraunhoferPropagator_hpp
28#define wfp_fraunhoferPropagator_hpp
30#include "../mxlib.hpp"
32#include "../math/constants.hpp"
36#include "../math/ft/fftT.hpp"
39#include "../math/cuda/templateCuda.hpp"
40#include "../math/cuda/cudaPtr.hpp"
41#include "../math/cuda/templateCufft.hpp"
42#include "../math/cuda/templateCublas.hpp"
// Debug helper: when expanded, writes "DEBUG: <file> <line>" followed by a
// newline to std::cout (std::cout must already be declared where it is used).
// NOTE(review): the expansion is a bare expression statement that already ends
// in ';' (it is not wrapped in do { } while(0)). Writing `BREAD_CRUMB;` right
// before an `else` therefore leaves a stray empty statement and fails to compile.
46#define BREAD_CRUMB std::cout << "DEBUG: " << __FILE__ << " " << __LINE__ << "\n";
// Primary template, declared but not defined: a type selector that yields the
// array type (`arrayT`) used by fraunhoferPropagator for pupil- and
// focal-plane data. The specializations in this file are keyed on cudaGPU:
//   cudaGPU == 0 -> arrayT is wavefrontT itself (host memory),
//   cudaGPU == 1 -> arrayT is mx::cuda::cudaPtr<typename wavefrontT::Scalar>.
// Only these two specializations appear here. Any other cudaGPU value would
// presumably fail to compile (incomplete type), unless a specialization is
// provided elsewhere — confirm.
57template <
typename wavefrontT,
int cudaGPU>
58struct fraunhoferPropagatorArrayT;
60template <
typename wavefrontT>
61struct fraunhoferPropagatorArrayT<wavefrontT, 0>
63 typedef wavefrontT arrayT;
67template <
typename wavefrontT>
68struct fraunhoferPropagatorArrayT<wavefrontT, 1>
70 typedef mx::cuda::cudaPtr<typename wavefrontT::Scalar> arrayT;
87template <
typename _wavefrontT,
int _cudaGPU = 0>
95 static constexpr int cudaGPU = _cudaGPU;
101 typedef typename wavefrontT::Scalar::value_type
realT;
104 typedef typename fraunhoferPropagatorArrayT<wavefrontT, cudaGPU>::arrayT
arrayT;
127 math::ft::fftT<complexT, complexT, 2, cudaGPU>
m_fft_fwd;
158 template <
int ccudaGPU = cudaGPU>
160 typename std::enable_if<ccudaGPU == 0>::type * = 0 );
167 template <
int ccudaGPU = cudaGPU>
169 typename std::enable_if<ccudaGPU == 1>::type * = 0 );
175 template <
int ccudaGPU = cudaGPU>
177 typename std::enable_if<ccudaGPU == 0>::type * = 0 );
183 template <
int ccudaGPU = cudaGPU>
185 typename std::enable_if<ccudaGPU == 1>::type * = 0 );
223 template <
int ccudaGPU = cudaGPU>
227 template <
int ccudaGPU = cudaGPU>
238template <
typename wavefrontT,
int cudaGPU>
243template <
typename wavefrontT,
int cudaGPU>
248template <
typename wavefrontT,
int cudaGPU>
254template <
typename wavefrontT,
int cudaGPU>
260 if( m_wavefrontSizePixels > 0 )
266template <
typename wavefrontT,
int cudaGPU>
267template <
int ccudaGPU>
269 typename std::enable_if<ccudaGPU == 0>::type * )
271 complexPupil *= m_centerFocal;
276template <
typename wavefrontT,
int cudaGPU>
277template <
int ccudaGPU>
279 typename std::enable_if<ccudaGPU == 1>::type * )
284 cudaError_t ce = mx::cuda::elementwiseXxY(
reinterpret_cast<cuComplex *
>( complexPupil.m_devicePtr ),
285 reinterpret_cast<cuComplex *
>( m_centerFocal.m_devicePtr ),
286 pow( m_wavefrontSizePixels, 2 ) );
288 if( ce != cudaSuccess )
290 return internal::mxlib_error_report<verboseT>( cudaError2error_t( ce ),
"from mx::cuda::elementwiseXxY" );
300template <
typename wavefrontT,
int cudaGPU>
301template <
int ccudaGPU>
303 typename std::enable_if<ccudaGPU == 0>::type * )
305 complexPupil *= m_centerPupil;
310template <
typename wavefrontT,
int cudaGPU>
311template <
int ccudaGPU>
313 typename std::enable_if<ccudaGPU == 1>::type * )
316 cudaError_t ce = mx::cuda::elementwiseXxY(
reinterpret_cast<cuComplex *
>( complexPupil.m_devicePtr ),
317 reinterpret_cast<cuComplex *
>( m_centerPupil.m_devicePtr ),
318 pow( m_wavefrontSizePixels, 2 ) );
320 if( ce != cudaSuccess )
322 return internal::mxlib_error_report<verboseT>( cudaError2error_t( ce ),
"from mx::cuda::elementwiseXxY" );
329template <
typename wavefrontT,
int cudaGPU>
339 setWavefrontSizePixels( complexPupil.rows() );
346 error_t ec = shiftPupil( complexPupil );
350 return internal::mxlib_error_report<verboseT>(ec,
"from shiftPupil");
357 m_fft_fwd( complexFocal.data(), complexPupil.data() );
364template <
typename wavefrontT,
int cudaGPU>
373 setWavefrontSizePixels( complexPupil.rows() );
377 m_fft_back( complexPupil.data(), complexFocal.data() );
384 error_t ec = unshiftPupil( complexPupil );
388 return internal::mxlib_error_report<verboseT>(ec,
"from unshiftPupil");
398template <
typename wavefrontT,
int cudaGPU>
404 if( wfsPix == m_centerFocal.rows() )
411 m_wavefrontSizePixels = wfsPix;
415 m_xcen = 0.5 * ( wfsPix - 1.0 );
416 m_ycen = 0.5 * ( wfsPix - 1.0 );
424 m_fft_fwd.plan( wfsPix, wfsPix );
433template <
typename wavefrontT,
int cudaGPU>
434template <
int ccudaGPU>
436 complexT *centerPupil,
437 typename std::enable_if<ccudaGPU == 0>::type * )
439 m_centerFocal.resize( m_wavefrontSizePixels, m_wavefrontSizePixels );
440 m_centerPupil.resize( m_wavefrontSizePixels, m_wavefrontSizePixels );
443 for(
int cc = 0; cc < m_wavefrontSizePixels; ++cc )
445 for(
int rr = 0; rr < m_wavefrontSizePixels; ++rr )
447 m_centerFocal( rr, cc ) = centerFocal[cc * m_wavefrontSizePixels + rr];
448 m_centerPupil( rr, cc ) = centerPupil[cc * m_wavefrontSizePixels + rr];
453template <
typename wavefrontT,
int cudaGPU>
454template <
int ccudaGPU>
455void fraunhoferPropagator<wavefrontT, cudaGPU>::setShiftPhase( complexT *centerFocal,
456 complexT *centerPupil,
457 typename std::enable_if<ccudaGPU == 1>::type * )
463 m_centerFocal.upload( centerFocal, m_wavefrontSizePixels, m_wavefrontSizePixels );
467 m_centerPupil.upload( centerPupil, m_wavefrontSizePixels, m_wavefrontSizePixels );
474template <
typename wavefrontT,
int cudaGPU>
480 realT norm = 1. / ( m_wavefrontSizePixels * sqrt( 2 ) );
486 complexT *centerFocal =
new complexT[m_wavefrontSizePixels * m_wavefrontSizePixels];
488 complexT *centerPupil =
new complexT[m_wavefrontSizePixels * m_wavefrontSizePixels];
491 realT arg = -2.0 * pi * 0.5 * ( m_wavefrontSizePixels - m_wholePixel ) / ( m_wavefrontSizePixels - 1 );
493 for(
int cc = 0; cc < m_wavefrontSizePixels; ++cc )
495 for(
int rr = 0; rr < m_wavefrontSizePixels; ++rr )
497 centerFocal[cc * m_wavefrontSizePixels + rr] =
498 cnorm * exp(
complexT( 0., arg * ( ( rr - m_xcen ) + ( cc - m_ycen ) ) ) );
499 centerPupil[cc * m_wavefrontSizePixels + rr] =
500 cnorm * exp(
complexT( 0., 0.5 * pi - arg * ( ( rr - m_xcen ) + ( cc - m_ycen ) ) ) );
506 setShiftPhase( centerFocal, centerPupil );
510 delete[] centerFocal;
511 delete[] centerPupil;
Class to perform Fraunhofer propagation between pupil and focal planes.
fraunhoferPropagatorArrayT< wavefrontT, cudaGPU >::arrayT arrayT
The array data type.
fraunhoferPropagatorArrayT< wavefrontT, cudaGPU >::arrayT m_centerFocal
Phase screen for tilting the pupil plane so that the focal plane image is centered.
math::ft::fftT< complexT, complexT, 2, cudaGPU > m_fft_back
FFT object for backward FFTs.
wavefrontT::Scalar complexT
The complex data type.
void setWavefrontSizePixels(int wfsPix)
Set the size of the wavefront, in pixels.
error_t propagateFocalToPupil(arrayT &complexPupil, arrayT &complexFocal, bool doCenter=true)
Propagate the wavefront from Focal plane to Pupil plane.
int wholePixel()
Get the value of the wholePixel parameter.
fraunhoferPropagatorArrayT< wavefrontT, cudaGPU >::arrayT m_centerPupil
Phase screen for un-tilting the pupil plane after propagating from a centered focal plane.
error_t shiftPupil(arrayT &complexPupil, typename std::enable_if< ccudaGPU==0 >::type *=0)
Apply the centering shift and the normalization to a pupil wavefront.
void wholePixel(realT wp)
Set the value of the wholePixel parameter.
realT m_wholePixel
Determines how the image is centered.
wavefrontT::Scalar::value_type realT
The real data type.
~fraunhoferPropagator()
Destructor.
_wavefrontT wavefrontT
The wavefront data type.
error_t unshiftPupil(arrayT &complexPupil, typename std::enable_if< ccudaGPU==1 >::type *=0)
Apply the shift to a pupil wavefront which will restore it to a centered pupil image,...
int m_wavefrontSizePixels
The size of the wavefront in pixels.
realT m_xcen
x-coordinate of focal plane center, in pixels
fraunhoferPropagator()
Constructor.
realT m_ycen
y-coordinate of focal plane center, in pixels
error_t propagatePupilToFocal(arrayT &complexFocal, arrayT &complexPupil, bool doCenter=true)
Propagate the wavefront from the pupil plane to the focal plane.
void makeShiftPhase()
Calculate the complex tilt arrays for centering and normalizing the wavefronts.
error_t unshiftPupil(arrayT &complexPupil, typename std::enable_if< ccudaGPU==0 >::type *=0)
Apply the shift to a pupil wavefront which will restore it to a centered pupil image,...
math::ft::fftT< complexT, complexT, 2, cudaGPU > m_fft_fwd
FFT object for forward FFTs.
error_t shiftPupil(arrayT &complexPupil, typename std::enable_if< ccudaGPU==1 >::type *=0)
Apply the centering shift and the normalization to a pupil wavefront.
error_t
The mxlib error codes.
@ noerror
No error has occurred.
@ backward
Specifies the backward transform.
constexpr floatT six_fifths()
Return 6/5 in the specified precision.
Declares and defines a class for managing images.
Utilities for modeling image formation.
MXLIB_DEFAULT_VERBOSITY d
The default verbosity.