mxlib
c++ tools for analyzing astronomical data and other tasks by Jared R. Males. [git repo]
Loading...
Searching...
No Matches
templateCublas.hpp
Go to the documentation of this file.
1/** \file templateCublas.hpp
2 * \author Jared R. Males
3 * \brief A template interface to cuBLAS
4 * \ingroup cuda_files
5 *
6 */
7
8//***********************************************************************//
9// Copyright 2019,2020 Jared R. Males (jaredmales@gmail.com)
10//
11// This file is part of mxlib.
12//
13// mxlib is free software: you can redistribute it and/or modify
14// it under the terms of the GNU General Public License as published by
15// the Free Software Foundation, either version 3 of the License, or
16// (at your option) any later version.
17//
18// mxlib is distributed in the hope that it will be useful,
19// but WITHOUT ANY WARRANTY; without even the implied warranty of
20// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21// GNU General Public License for more details.
22//
23// You should have received a copy of the GNU General Public License
24// along with mxlib. If not, see <http://www.gnu.org/licenses/>.
25//***********************************************************************//
26
27#ifndef math_templateCublas_hpp
28#define math_templateCublas_hpp
29
30#ifdef MXLIB_CUDA
31
32#include <cuda_runtime.h>
33#include <cublas_v2.h>
34
35namespace mx
36{
37namespace cuda
38{
39
40/// Multiplies a vector by a scalar, overwriting the vector with the result.
41/** Implements
42 * \f[
43 * \vec{x} = \alpha \vec{x}
44 * \f]
45 *
46 * Specializations are provided for float, double, complex-float, and complex-double
47 *
48 * \tparam floatT a floating-point type, either float, double, complex-float, or complex-double
49 *
 * \returns a cublasStatus_t status code.
 *
50 * \ingroup cublas
51 */
52template <typename floatT>
53cublasStatus_t cublasTscal( cublasHandle_t handle, ///< [in] handle to the cuBLAS library context.
54 int n, ///< [in] Number of elements in the vector
55 const floatT *alpha, ///< [in] The scalar that multiplies the vector
56 floatT *x, ///< [in,out] The vector of length n, overwritten with the scaled result
57 int incx ///< [in] The stride of the vector
58);
59
60/// Multiplies a vector by a scalar, adding it to a second vector which is overwritten by the result.
61/** Implements
62 * \f[
63 * \vec{y} = \alpha \vec{x} + \vec{y}
64 * \f]
65 *
66 * Specializations are provided for float, double, complex-float, and complex-double
67 *
68 * \tparam floatT a floating-point type, either float, double, complex-float, or complex-double
69 *
70 * \test Scenario: scaling and accumulating a vector with cublas \ref test_math_templateCublas_axpy "[test doc]"
71 *
 * \returns a cublasStatus_t status code.
 *
72 * \ingroup cublas
73 */
74template <typename floatT>
75cublasStatus_t cublasTaxpy( cublasHandle_t handle, ///< [in] handle to the cuBLAS library context.
76 int n, ///< [in] number of elements in the vectors x and y
77 const floatT *alpha, ///< [in] scalar used for multiplication.
78 const floatT *x, ///< [in] vector with n elements.
79 int incx, ///< [in] stride between consecutive elements of x
80 floatT *y, ///< [in,out] vector with n elements, overwritten with the result.
81 int incy ///< [in] stride between consecutive elements of y
82);
83
84//----------------------------------------------------
85// Element-wise (Hadamard) products of vectors
86
87/// Calculates the element-wise product of two vectors, storing the result in the first.
88/** Calculates
89 * \f$
90 * x = x * y
91 * \f$
92 * element by element, a.k.a. the Hadamard product.
93 *
94 * Specializations are provided for:
95 * - float, float
96 * - complex-float, float
97 * - complex-float, complex-float
98 * - double, double
99 * - complex-double, double
100 * - complex-double, complex-double
101 *
 * \returns a cudaError_t status code.
 *
102 * \ingroup cublas
103 */
104template <typename dataT1, typename dataT2>
105cudaError_t elementwiseXxY(
106 dataT1 *x, ///< [in,out] device pointer for the 1st vector. Is replaced with the product of the two vectors
107 dataT2 *y, ///< [in] device pointer for the 2nd vector.
108 int size ///< [in] the number of elements in the vectors.
109);
110
111/// Calculates the element-wise product of two vectors, storing the result in a third vector.
112/** Calculates
113 * \f$
114 * z = x * y
115 * \f$
116 * element by element, a.k.a. the Hadamard product.
117 *
118 * Specializations are provided for:
119 * - float, float, float
120 * - complex-float, complex-float, float
121 * - complex-float, complex-float, complex-float
122 * - double, double, double
123 * - complex-double, complex-double, double
124 * - complex-double, complex-double, complex-double
125 *
 * \returns a cudaError_t status code.
 *
126 * \ingroup cublas
127 */
128template <typename dataT0, typename dataT1, typename dataT2>
129cudaError_t elementwiseXxY( dataT0 *z, /**< [out] device pointer for the result vector. Is filled in with the
130 element-wise product of x and y*/
131 dataT1 *x, ///< [in] device pointer for the 1st vector.
132 dataT2 *y, ///< [in] device pointer for the 2nd vector.
133 int size ///< [in] the number of elements in the vectors.
134);
135
136/// Calculates the element-wise product of two vectors, accumulating the result in a third vector.
137/** Calculates
138 * \f$
139 * z += x * y
140 * \f$
141 * element by element, a.k.a. the Hadamard product of x and y.
142 *
143 * Specializations are provided for:
144 * - float, float, float
145 * - double, double, double
146 *
147 * \todo Complex overloads won't compile for some reason.
148 *
 * \returns a cudaError_t status code.
 *
149 * \ingroup cublas
150 */
151template <typename dataT0, typename dataT1, typename dataT2>
152cudaError_t elementwiseXxYAccum( dataT0 *z, /**< [in,out] device pointer for the accumulator vector. The
153 element-wise product of x and y is added to its current contents*/
154 dataT1 *x, ///< [in] device pointer for the 1st vector.
155 dataT2 *y, ///< [in] device pointer for the 2nd vector.
156 int size ///< [in] the number of elements in the vectors.
157);
158//----------------------------------------------------
159// Tgemv
160
161/// Perform a matrix-vector multiplication.
162/** Implements
163 * \f[
164 * \vec{y} = \alpha \mathbf{A} \vec{x} + \beta \vec{y}
165 * \f]
166 *
167 * Specializations are provided for float, double, complex-float, and complex-double
 * \note Only the float and double specializations are declared in this header - TODO confirm where the
 *       complex ones are declared.
168 *
169 * \tparam floatT a floating-point type, either float, double, complex-float, or complex-double
170 *
171 * Tests:
172 * - Multiplying a vector by a matrix \ref test_math_templateCublas_cublasTgemv_inc "[code doc]"
173 *
 * \returns a cublasStatus_t status code.
 *
174 * \ingroup cublas
175 */
176template <typename floatT>
177cublasStatus_t
178cublasTgemv( cublasHandle_t handle, /**< [in] handle to the cuBLAS library context. */
179 cublasOperation_t trans, /**< [in] operation on A, CUBLAS_OP_N for none, and CUBLAS_OP_T for transpose */
180 int m, /**< [in] [host] rows in matrix A. */
181 int n, /**< [in] [host] columns in matrix A. */
182 const floatT *alpha, /**< [in] [host/device] scalar used for multiplication of A */
183 const floatT *A, /**< [in] [device] array holding the m x n matrix A with leading dimension lda
184 (cuBLAS gemv expects an lda x n array - confirm against the
185 implementation). */
186 int lda, /**< [in] [host] leading dimension of A. lda must be at least max(1,m). */
187 const floatT *x, /**< [in] [device] vector of at least (1+(n-1)*abs(incx)) elements if
188 trans==CUBLAS_OP_N and at least (1+(m-1)*abs(incx))
189 elements otherwise. */
190 int incx, /**< [in] [host] stride of x. */
191 const floatT *beta, /**< [in] [host/device] scalar used for multiplication of y, if beta==0
192 then y does not need to be initialized.*/
193 floatT *y, /**< [in,out] [device] vector of at least (1+(m-1)*abs(incy)) elements
194 if trans==CUBLAS_OP_N and at
195 least (1+(n-1)*abs(incy)) elements otherwise.*/
196 int incy /**< [in] [host] stride of y */
197);
198
199/// Perform a matrix-vector multiplication for stride-less arrays
200/** Implements
201 * \f[
202 * \vec{y} = \alpha \mathbf{A} \vec{x} + \beta \vec{y}
203 * \f]
204 *
205 * Specializations are provided for float, double, complex-float, and complex-double
206 *
207 * \overload
208 * This version assumes stride is 1 in all arrays.
209 *
210 * \tparam floatT a floating-point type, either float, double, complex-float, or complex-double
211 *
 * \returns a cublasStatus_t status code.
 *
212 * \ingroup cublas
213 */
214template <typename floatT>
215cublasStatus_t
216cublasTgemv( cublasHandle_t handle, /**< [in] handle to the cuBLAS library context. */
217 cublasOperation_t trans, /**< [in] operation on A, CUBLAS_OP_N for none, and CUBLAS_OP_T for transpose */
218 int m, /**< [in] rows in matrix A. */
219 int n, /**< [in] columns in matrix A. */
220 const floatT *alpha, /**< [in] scalar used for multiplication of A */
221 const floatT *A, /**< [in] [device] array holding the m x n matrix A. No leading dimension is
222 passed to this overload (presumably it is taken as m -
223 TODO confirm against the implementation). */
224 const floatT *x, /**< [in] [device] vector of at least n elements if
225 trans==CUBLAS_OP_N and at least m
226 elements otherwise (stride is 1). */
227 const floatT *beta, /**< [in] [host/device] scalar used for multiplication of y, if beta==0
228 then y does not need to be initialized.*/
229 floatT *y /**< [in,out] [device] vector of at least m elements
230 if trans==CUBLAS_OP_N and at
231 least n elements otherwise (stride is 1).*/
232);
233
/// Declares the float explicit specialization of the cublasTgemv overload that takes lda, incx and incy.
234template <>
235cublasStatus_t cublasTgemv<float>( cublasHandle_t handle,
236 cublasOperation_t trans,
237 int m,
238 int n,
239 const float *alpha,
240 const float *A,
241 int lda,
242 const float *x,
243 int incx,
244 const float *beta,
245 float *y,
246 int incy );
247
/// Declares the double explicit specialization of the cublasTgemv overload that takes lda, incx and incy.
248template <>
249cublasStatus_t cublasTgemv<double>( cublasHandle_t handle,
250 cublasOperation_t trans,
251 int m,
252 int n,
253 const double *alpha,
254 const double *A,
255 int lda,
256 const double *x,
257 int incx,
258 const double *beta,
259 double *y,
260 int incy );
261
/// Declares the float explicit specialization of the cublasTgemv overload without lda or stride arguments.
262template <>
263cublasStatus_t cublasTgemv<float>( cublasHandle_t handle,
264 cublasOperation_t trans,
265 int m,
266 int n,
267 const float *alpha,
268 const float *A,
269 const float *x,
270 const float *beta,
271 float *y );
272
/// Declares the double explicit specialization of the cublasTgemv overload without lda or stride arguments.
273template <>
274cublasStatus_t cublasTgemv<double>( cublasHandle_t handle,
275 cublasOperation_t trans,
276 int m,
277 int n,
278 const double *alpha,
279 const double *A,
280 const double *x,
281 const double *beta,
282 double *y );
283
284} // namespace cuda
285} // namespace mx
286
287#endif // MXLIB_CUDA
288
289#endif // math_templateCublas_hpp
The mxlib c++ namespace.
Definition mxlib.hpp:37